OpenDAS / Megatron-LM
Commit f0a445fa, authored Nov 27, 2020 by mohammad
added consumed tokens to checkpoints and some refactoring
parent 4311b695

Showing 5 changed files with 77 additions and 41 deletions
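
In outline: the commit adds two counters, args.consumed_train_samples and args.consumed_valid_samples. They are initialized in parse_args, advanced by one global batch per step in train() and evaluate(), passed to save_checkpoint and restored in load_checkpoint, and used to seed the new build_pretraining_data_loader so that a resumed run continues from the samples it has already consumed instead of replaying them. The per-file hunks below show each piece; short illustrative sketches follow the checkpointing, data-loader, and training hunks.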
megatron/arguments.py           +3   -0
megatron/checkpointing.py       +13  -2
megatron/data/data_loaders.py   +32  -1
megatron/training.py            +29  -13
megatron/utils.py               +0   -25
megatron/arguments.py

@@ -72,6 +72,9 @@ def parse_args(extra_args_provider=None, defaults={},
     print('using {} for parameters ...'.format(args.params_dtype),
           flush=True)

+    # Consumed tokens.
+    args.consumed_train_samples = 0
+    args.consumed_valid_samples = 0

     # Set input defaults.
     for key in defaults:
megatron/checkpointing.py

@@ -89,7 +89,8 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                    consumed_train_samples=None, consumed_valid_samples=None):
     """Save a model checkpoint."""
     args = get_args()

@@ -103,6 +104,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
+        if consumed_train_samples:
+            state_dict['consumed_train_samples'] = consumed_train_samples
+        if consumed_valid_samples:
+            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()

         # Optimizer stuff.

@@ -213,7 +218,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                          'iteration from checkpoint {}, exiting'.format(
                              checkpoint_name))
             sys.exit()

+    if 'consumed_train_samples' in state_dict:
+        assert args.consumed_train_samples == 0
+        args.consumed_train_samples = state_dict['consumed_train_samples']
+    if 'consumed_valid_samples' in state_dict:
+        assert args.consumed_valid_samples == 0
+        args.consumed_valid_samples = state_dict['consumed_valid_samples']

     # Check arguments.
     if 'args' in state_dict:
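
The counters are written only when the caller passes them, and on load they are asserted to still be zero before being overwritten, so checkpoints produced before this commit simply leave them untouched. Below is a standalone toy round-trip that mirrors this logic; it is a sketch, not the repository's save_checkpoint/load_checkpoint, and the helper names, the temp-file path, and the SimpleNamespace stand-in for Megatron's args are all illustrative:

import os
import tempfile
from types import SimpleNamespace

import torch


def save_counters(path, iteration, consumed_train_samples=None,
                  consumed_valid_samples=None):
    # Mirrors the new save_checkpoint() branches: persist only counters
    # that were actually supplied by the caller.
    state_dict = {'checkpoint_version': 2.0, 'iteration': iteration}
    if consumed_train_samples:
        state_dict['consumed_train_samples'] = consumed_train_samples
    if consumed_valid_samples:
        state_dict['consumed_valid_samples'] = consumed_valid_samples
    torch.save(state_dict, path)


def load_counters(path, args):
    # Mirrors the new load_checkpoint() branches: restore counters only if
    # the checkpoint contains them, and insist they were not set already.
    state_dict = torch.load(path)
    if 'consumed_train_samples' in state_dict:
        assert args.consumed_train_samples == 0
        args.consumed_train_samples = state_dict['consumed_train_samples']
    if 'consumed_valid_samples' in state_dict:
        assert args.consumed_valid_samples == 0
        args.consumed_valid_samples = state_dict['consumed_valid_samples']
    return state_dict['iteration']


if __name__ == '__main__':
    path = os.path.join(tempfile.mkdtemp(), 'ckpt.pt')
    save_counters(path, iteration=500,
                  consumed_train_samples=256000, consumed_valid_samples=8000)
    args = SimpleNamespace(consumed_train_samples=0, consumed_valid_samples=0)
    iteration = load_counters(path, args)
    print(iteration, args.consumed_train_samples, args.consumed_valid_samples)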
megatron/data/sampler.py → megatron/data/data_loaders.py

@@ -13,7 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Megatorn Sampler."""
+"""Dataloaders."""


+import torch
+
+from megatron import get_args
+from megatron import mpu
+
+
+def build_pretraining_data_loader(dataset, consumed_samples):
+    """Buld dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * world_size
+
+    # Megatron sampler
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=consumed_samples,
+        global_batch_size=global_batch_size,
+        rank=mpu.get_data_parallel_rank(),
+        world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=args.num_workers,
+                                       pin_memory=True)


 class MegatronPretrainingSampler:
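
build_pretraining_data_loader hands the consumed-samples count to MegatronPretrainingSampler, whose body is collapsed in this hunk. The following is a minimal, self-contained sketch of a batch sampler with that constructor signature, written purely to illustrate how the count lets a resumed dataloader start where the previous run stopped; it is not the repository's implementation, and the class name and slicing scheme are assumptions:

class SkipAwarePretrainingSampler:
    """Toy batch sampler with MegatronPretrainingSampler's constructor signature.

    Yields this data-parallel rank's slice of each global batch, starting
    from `consumed_samples`, so a resumed run does not revisit data.
    """

    def __init__(self, total_samples, consumed_samples, global_batch_size,
                 rank, world_size):
        assert global_batch_size % world_size == 0
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.global_batch_size = global_batch_size
        self.per_rank_batch = global_batch_size // world_size
        self.rank = rank

    def __len__(self):
        return (self.total_samples - self.consumed_samples) // self.global_batch_size

    def __iter__(self):
        start = self.consumed_samples
        while start + self.global_batch_size <= self.total_samples:
            # Contiguous slice of the current global batch owned by this rank.
            begin = start + self.rank * self.per_rank_batch
            yield list(range(begin, begin + self.per_rank_batch))
            start += self.global_batch_size


if __name__ == '__main__':
    # Resuming after 16 consumed samples: rank 0 of 2 gets [16, 17], [20, 21], ...
    sampler = SkipAwarePretrainingSampler(total_samples=32, consumed_samples=16,
                                          global_batch_size=4, rank=0, world_size=2)
    print(list(sampler))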
megatron/training.py

@@ -37,7 +37,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import make_data_loader
+from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory

@@ -104,7 +104,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                        consumed_train_samples=args.consumed_train_samples,
+                        consumed_valid_samples=args.consumed_valid_samples)

     if args.do_test:
         # Run on test data.

@@ -224,7 +226,8 @@ def setup_model_and_optimizer(model_provider_func):
     while hasattr(unwrapped_model, 'module'):
         unwrapped_model = unwrapped_model.module
-    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
+    if args.iteration == 0 and hasattr(unwrapped_model,
+                                       'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()

@@ -414,6 +417,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                                        optimizer,
                                                        lr_scheduler)
         iteration += 1
+        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+                                       args.batch_size

         # Logging.
         loss_scale = None

@@ -433,7 +438,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                            consumed_train_samples=args.consumed_train_samples,
+                            consumed_valid_samples=args.consumed_valid_samples)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \

@@ -472,6 +479,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                                  args.eval_iters))

             # Forward evaluation.
             _, loss_dict = forward_step_func(data_iterator, model)
+            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
+                                           * args.batch_size

             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \

@@ -517,11 +526,19 @@ def build_train_valid_test_data_iterators(
     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

     print_rank_0('> building train, validation, and test datasets ...')

+    # Rank and global batch size.
+    data_parallel_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * data_parallel_size
+    # Backward compatibility, assume fixed batch size.
+    if args.iteration > 0 and args.consumed_train_samples == 0:
+        args.consumed_train_samples = args.iteration * global_batch_size
+    if args.iteration > 0 and args.consumed_valid_samples == 0:
+        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+            args.eval_iters * global_batch_size

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        # Rank, size, and global batch size.
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        global_batch_size = args.batch_size * data_parallel_size

         # Number of train/valid/test samples.
         train_iters = args.train_iters

@@ -540,12 +557,11 @@ def build_train_valid_test_data_iterators(
                                    train_val_test_num_samples)

         # Build dataloders.
-        comsumed_samples = args.iteration * global_batch_size
-        train_dataloader = make_data_loader(train_ds, comsumed_samples)
-        comsumed_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * global_batch_size
-        valid_dataloader = make_data_loader(valid_ds, comsumed_samples)
-        test_dataloader = make_data_loader(test_ds, comsumed_samples)
+        train_dataloader = build_pretraining_data_loader(
+            train_ds, args.consumed_train_samples)
+        valid_dataloader = build_pretraining_data_loader(
+            valid_ds, args.consumed_valid_samples)
+        test_dataloader = build_pretraining_data_loader(test_ds, 0)

         # Flags to know if we need to do training/validation/testing.
         do_train = train_dataloader is not None and args.train_iters > 0
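
For checkpoints written before this commit, load_checkpoint leaves both counters at zero, so build_train_valid_test_data_iterators falls back to deriving them from the iteration count under the fixed-batch-size assumption shown above. A standalone arithmetic check of that fallback (the function name and the numbers are purely illustrative):

def consumed_from_iteration(iteration, batch_size, data_parallel_size,
                            eval_interval, eval_iters):
    # Backward-compatibility fallback: reconstruct the counters from the
    # iteration number, assuming the global batch size never changed.
    global_batch_size = batch_size * data_parallel_size
    consumed_train = iteration * global_batch_size
    consumed_valid = (iteration // eval_interval) * eval_iters * global_batch_size
    return consumed_train, consumed_valid


if __name__ == '__main__':
    # E.g. 1000 iterations at batch_size 8 on 16 data-parallel ranks,
    # evaluating for 10 iterations every 100 training iterations:
    train, valid = consumed_from_iteration(iteration=1000, batch_size=8,
                                           data_parallel_size=16,
                                           eval_interval=100, eval_iters=10)
    print(train, valid)  # 128000 12800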
megatron/utils.py

@@ -24,7 +24,6 @@ from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.checkpointing import save_checkpoint
-from megatron.data.sampler import MegatronPretrainingSampler
 from megatron.fp16 import FP16_Optimizer

@@ -89,30 +88,6 @@ def check_adlr_autoresume_termination(iteration, model,
             sys.exit(0)


-def make_data_loader(dataset, consumed_samples):
-    """Buld dataloader given an input dataset."""
-
-    if dataset is None:
-        return None
-    args = get_args()
-
-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset),
-        consumed_samples=consumed_samples,
-        global_batch_size=global_batch_size,
-        rank=rank,
-        world_size=world_size)
-
-    # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
 def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,