OpenDAS / Megatron-LM

Commit 45504541, authored Dec 02, 2020 by mohammad

Merge branch 'blendable_dataset' into refactor_learning_rate

Parents: ff12df6b, 98989693
Showing 4 changed files with 18 additions and 25 deletions:

    megatron/arguments.py       +9   -4
    megatron/checkpointing.py   +5  -13
    megatron/data/helpers.cpp   +2   -2
    megatron/training.py        +2   -6
megatron/arguments.py

@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},

 def _print_args(args):
     """Print arguments."""
     if args.rank == 0:
-        print('-------------------- arguments --------------------', flush=True)
+        print('------------------------ arguments ------------------------',
+              flush=True)
         str_list = []
         for arg in vars(args):
-            dots = '.' * (32 - len(arg))
+            dots = '.' * (48 - len(arg))
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('---------------- end of arguments ----------------', flush=True)
+        print('-------------------- end of arguments ---------------------',
+              flush=True)


 def _check_arg_is_not_none(args, arg):

@@ -401,7 +403,10 @@ def _add_data_args(parser):

     group = parser.add_argument_group(title='data and dataloader')

     group.add_argument('--data-path', nargs='*', default=None,
-                       help='Path to combined dataset to split.')
+                       help='Path to the training dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
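The new --data-path format alternates dataset weights and path prefixes. As a rough illustration of how such a list can be split apart, here is a minimal sketch (parse_weighted_data_paths is a hypothetical helper invented for this sketch, not code from this commit):

def parse_weighted_data_paths(data_path):
    # Hypothetical helper: split a --data-path list into (weights, prefixes).
    # A single entry means one unweighted dataset; otherwise entries alternate
    # as weight1 path1 weight2 path2 ...
    if len(data_path) == 1:
        return [1.0], list(data_path)
    assert len(data_path) % 2 == 0, 'expected weight/path pairs'
    weights = [float(w) for w in data_path[0::2]]
    prefixes = list(data_path[1::2])
    return weights, prefixes

# Example: blend two corpora at a 30/70 ratio.
weights, prefixes = parse_weighted_data_paths(
    ['0.3', '/data/wiki_text_document', '0.7', '/data/cc_text_document'])
assert weights == [0.3, 0.7]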
megatron/checkpointing.py

@@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):

     return os.path.join(checkpoints_path,
                         'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                    consumed_train_samples=None, consumed_valid_samples=None):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler,

         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
-        if consumed_train_samples:
-            state_dict['consumed_train_samples'] = consumed_train_samples
-        if consumed_valid_samples:
-            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()

         # Optimizer stuff.

@@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):

                 checkpoint_name))
             sys.exit()

-    if 'consumed_train_samples' in state_dict:
-        assert args.consumed_train_samples == 0
-        args.consumed_train_samples = state_dict['consumed_train_samples']
-    if 'consumed_valid_samples' in state_dict:
-        assert args.consumed_valid_samples == 0
-        args.consumed_valid_samples = state_dict['consumed_valid_samples']
-
     # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0)
+        args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
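The net effect of these hunks: the consumed-sample counters no longer get dedicated state_dict entries; they ride along inside the saved args object and are read back by load_checkpoint. A minimal round-trip sketch, using SimpleNamespace as a stand-in for Megatron's args object (not Megatron code):

from types import SimpleNamespace

# "Saving": args, counters included, is stored wholesale under 'args'.
saved_args = SimpleNamespace(consumed_train_samples=3200,
                             consumed_valid_samples=640)
state_dict = {'args': saved_args, 'iteration': 100}

# "Loading": a fresh run starts with zeroed counters ...
args = SimpleNamespace(consumed_train_samples=0, consumed_valid_samples=0)
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
# ... and recovers them from the checkpointed args.
checkpoint_args = state_dict['args']
args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0)
args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0)
assert args.consumed_train_samples == 3200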
megatron/data/helpers.cpp

@@ -60,7 +60,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,

     for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

         // Determine where the max error in sampling is happening.
-        double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
+        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
         int64_t max_error_index = 0;
         double max_error = weights_ptr[0] * sample_idx_double -
             static_cast<double>(current_samples[0]);

@@ -86,7 +86,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,

         if (verbose) {
             std::cout << " > sample ratios:" << std::endl;
             for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
-                double ratio = static_cast<double>(current_samples[dataset_idx]) /
+                auto ratio = static_cast<double>(current_samples[dataset_idx]) /
                     static_cast<double>(size);
                 std::cout << "   dataset " << dataset_idx << ", input: " <<
                     weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
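For context, build_blending_indices greedily assigns each sample of the blended dataset to whichever component dataset lags its weighted target the most. A rough Python analogue of the loop touched above (illustrative only; simplified from the C++ and not part of this commit):

def build_blending_indices(weights, size):
    # For each output sample, pick the dataset with the largest sampling
    # error, i.e. weighted target minus samples drawn from it so far.
    num_datasets = len(weights)
    current_samples = [0] * num_datasets
    dataset_index, dataset_sample_index = [], []
    for sample_idx in range(size):
        sample_idx_double = max(float(sample_idx), 1.0)
        errors = [weights[i] * sample_idx_double - current_samples[i]
                  for i in range(num_datasets)]
        max_error_index = errors.index(max(errors))
        dataset_index.append(max_error_index)
        dataset_sample_index.append(current_samples[max_error_index])
        current_samples[max_error_index] += 1
    return dataset_index, dataset_sample_index

# Example: a 30/70 blend over 10 samples draws 3 and 7 times respectively.
idx, _ = build_blending_indices([0.3, 0.7], 10)
assert idx.count(1) == 7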
megatron/training.py

@@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,

                                    iteration, False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                        consumed_train_samples=args.consumed_train_samples,
-                        consumed_valid_samples=args.consumed_valid_samples)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)

     if args.do_test:
         # Run on test data.

@@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                            consumed_train_samples=args.consumed_train_samples,
-                            consumed_valid_samples=args.consumed_valid_samples)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \