Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
98989693
Commit
98989693
authored
Dec 02, 2020
by
mohammad
Browse files
addressed Jareds comments
parent
bc56e4a5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
6 deletions
+11
-6
megatron/arguments.py
megatron/arguments.py
+9
-4
megatron/data/helpers.cpp
megatron/data/helpers.cpp
+2
-2
No files found.
megatron/arguments.py
View file @
98989693
...
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
...
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
def
_print_args
(
args
):
def
_print_args
(
args
):
"""Print arguments."""
"""Print arguments."""
if
args
.
rank
==
0
:
if
args
.
rank
==
0
:
print
(
'-------------------- arguments --------------------'
,
flush
=
True
)
print
(
'------------------------ arguments ------------------------'
,
flush
=
True
)
str_list
=
[]
str_list
=
[]
for
arg
in
vars
(
args
):
for
arg
in
vars
(
args
):
dots
=
'.'
*
(
32
-
len
(
arg
))
dots
=
'.'
*
(
48
-
len
(
arg
))
str_list
.
append
(
' {} {} {}'
.
format
(
arg
,
dots
,
getattr
(
args
,
arg
)))
str_list
.
append
(
' {} {} {}'
.
format
(
arg
,
dots
,
getattr
(
args
,
arg
)))
for
arg
in
sorted
(
str_list
,
key
=
lambda
x
:
x
.
lower
()):
for
arg
in
sorted
(
str_list
,
key
=
lambda
x
:
x
.
lower
()):
print
(
arg
,
flush
=
True
)
print
(
arg
,
flush
=
True
)
print
(
'---------------- end of arguments ----------------'
,
flush
=
True
)
print
(
'-------------------- end of arguments ---------------------'
,
flush
=
True
)
def
_check_arg_is_not_none
(
args
,
arg
):
def
_check_arg_is_not_none
(
args
,
arg
):
...
@@ -401,7 +403,10 @@ def _add_data_args(parser):
...
@@ -401,7 +403,10 @@ def _add_data_args(parser):
group
=
parser
.
add_argument_group
(
title
=
'data and dataloader'
)
group
=
parser
.
add_argument_group
(
title
=
'data and dataloader'
)
group
.
add_argument
(
'--data-path'
,
nargs
=
'*'
,
default
=
None
,
group
.
add_argument
(
'--data-path'
,
nargs
=
'*'
,
default
=
None
,
help
=
'Path to combined dataset to split.'
)
help
=
'Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...'
)
group
.
add_argument
(
'--split'
,
type
=
str
,
default
=
'969, 30, 1'
,
group
.
add_argument
(
'--split'
,
type
=
str
,
default
=
'969, 30, 1'
,
help
=
'Comma-separated list of proportions for training,'
help
=
'Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
' validation, and test split. For example the split '
...
...
megatron/data/helpers.cpp
View file @
98989693
...
@@ -60,7 +60,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
...
@@ -60,7 +60,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
for
(
int64_t
sample_idx
=
0
;
sample_idx
<
size
;
++
sample_idx
)
{
for
(
int64_t
sample_idx
=
0
;
sample_idx
<
size
;
++
sample_idx
)
{
// Determine where the max error in sampling is happening.
// Determine where the max error in sampling is happening.
double
sample_idx_double
=
std
::
max
(
static_cast
<
double
>
(
sample_idx
),
1.0
);
auto
sample_idx_double
=
std
::
max
(
static_cast
<
double
>
(
sample_idx
),
1.0
);
int64_t
max_error_index
=
0
;
int64_t
max_error_index
=
0
;
double
max_error
=
weights_ptr
[
0
]
*
sample_idx_double
-
double
max_error
=
weights_ptr
[
0
]
*
sample_idx_double
-
static_cast
<
double
>
(
current_samples
[
0
]);
static_cast
<
double
>
(
current_samples
[
0
]);
...
@@ -86,7 +86,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
...
@@ -86,7 +86,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
if
(
verbose
)
{
if
(
verbose
)
{
std
::
cout
<<
" > sample ratios:"
<<
std
::
endl
;
std
::
cout
<<
" > sample ratios:"
<<
std
::
endl
;
for
(
int64_t
dataset_idx
=
0
;
dataset_idx
<
num_datasets
;
++
dataset_idx
)
{
for
(
int64_t
dataset_idx
=
0
;
dataset_idx
<
num_datasets
;
++
dataset_idx
)
{
double
ratio
=
static_cast
<
double
>
(
current_samples
[
dataset_idx
])
/
auto
ratio
=
static_cast
<
double
>
(
current_samples
[
dataset_idx
])
/
static_cast
<
double
>
(
size
);
static_cast
<
double
>
(
size
);
std
::
cout
<<
" dataset "
<<
dataset_idx
<<
", input: "
<<
std
::
cout
<<
" dataset "
<<
dataset_idx
<<
", input: "
<<
weights_ptr
[
dataset_idx
]
<<
", achieved: "
<<
ratio
<<
std
::
endl
;
weights_ptr
[
dataset_idx
]
<<
", achieved: "
<<
ratio
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment