OpenDAS / Megatron-LM / Commits / bf3ce751

Commit bf3ce751, authored Apr 01, 2020 by Mohammad

    addressed comments from raul, neel, and jared

parent 8600642e
Showing 5 changed files with 61 additions and 61 deletions.
megatron/data/samplers.py   +3  -4
megatron/global_vars.py     +47 -46
megatron/initialize.py      +3  -3
megatron/training.py        +4  -4
pretrain_gpt2.py            +4  -4
megatron/data/samplers.py

...
@@ -76,12 +76,11 @@ class RandomSampler(data.sampler.Sampler):
 
 class DistributedBatchSampler(data.sampler.BatchSampler):
-    """
-    similar to normal implementation of distributed sampler, except
+    """Similar to normal implementation of distributed sampler, except
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
     sampler.
     """
 
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
                  world_size=2, wrap_last=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
...
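The change above is purely a docstring cleanup. For context, a minimal usage sketch of the class being edited: wrapping a base torch sampler with DistributedBatchSampler so each rank draws its own slice of batches. The dataset, batch size, rank, and world size below are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch (not from this commit).
import torch
from torch.utils import data
from megatron.data.samplers import DistributedBatchSampler

dataset = data.TensorDataset(torch.arange(1000).float())
base_sampler = data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(base_sampler,
                                        batch_size=8,
                                        drop_last=True,
                                        rank=0,          # this process's rank (assumed)
                                        world_size=2)    # total number of ranks (assumed)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler)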
megatron/global_vars.py

...
@@ -141,59 +141,60 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)
 
 
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, 'timer has already been started'
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, 'timer is not started'
+        torch.cuda.synchronize()
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If the timing in progress, end it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time
+        if reset:
+            self.reset()
+        # If timing was in progress, set it back.
+        if started_:
+            self.start()
+        return elapsed_
+
+
 class Timers:
     """Group of timers."""
 
-    class Timer:
-        """Timer."""
-
-        def __init__(self, name):
-            self.name_ = name
-            self.elapsed_ = 0.0
-            self.started_ = False
-            self.start_time = time.time()
-
-        def start(self):
-            """Start the timer."""
-            assert not self.started_, 'timer has already been started'
-            torch.cuda.synchronize()
-            self.start_time = time.time()
-            self.started_ = True
-
-        def stop(self):
-            """Stop the timer."""
-            assert self.started_, 'timer is not started'
-            torch.cuda.synchronize()
-            self.elapsed_ += (time.time() - self.start_time)
-            self.started_ = False
-
-        def reset(self):
-            """Reset timer."""
-            self.elapsed_ = 0.0
-            self.started_ = False
-
-        def elapsed(self, reset=True):
-            """Calculate the elapsed time."""
-            started_ = self.started_
-            # If the timing in progress, end it first.
-            if self.started_:
-                self.stop()
-            # Get the elapsed time.
-            elapsed_ = self.elapsed_
-            # Reset the elapsed time
-            if reset:
-                self.reset()
-            # If timing was in progress, set it back.
-            if started_:
-                self.start()
-            return elapsed_
-
     def __init__(self):
         self.timers = {}
 
     def __call__(self, name):
         if name not in self.timers:
-            self.timers[name] = self.Timer(name)
+            self.timers[name] = _Timer(name)
         return self.timers[name]
 
     def write(self, names, writer, iteration, normalizer=1.0, reset=False):
...
@@ -212,7 +213,7 @@ class Timers:
         string = 'time (ms)'
         for name in names:
             elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
+                reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
...
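The refactor above lifts the nested Timers.Timer class out to a module-level _Timer, and Timers.__call__ now constructs _Timer(name) instead of self.Timer(name). For context, a minimal sketch of how this timer API is typically driven; the timer name, the commented-out writer, and the iteration variable are illustrative assumptions, not part of this commit. Note that start()/stop() call torch.cuda.synchronize(), so this assumes a CUDA-capable environment (Megatron requires CUDA, per initialize.py below).

# Hypothetical usage sketch (not from this commit).
from megatron.global_vars import Timers

timers = Timers()

timers('forward').start()
# ... run the forward pass here ...
timers('forward').stop()

# elapsed() returns seconds; by default it also resets the timer.
print('forward took {:.3f} s'.format(timers('forward').elapsed(reset=False)))

# write() logs selected timers in milliseconds (optionally normalized) on rank 0;
# `writer` and `iteration` are assumed to exist in the caller:
# timers.write(['forward'], writer, iteration, normalizer=1.0, reset=True)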
megatron/initialize.py

...
@@ -17,8 +17,8 @@
 import random
 import os
 
-import numpy as np
+import numpy as np
 import torch
 
 from megatron import get_adlr_autoresume
...
@@ -31,7 +31,7 @@ from megatron.global_vars import set_global_variables
 def initialize_megatron(extra_args_provider=None, args_defaults={}):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds."""
 
-    # Male sure cuda is avaiable.
+    # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
     # Parse args, build tokenizer, and set adlr-autoresume,
...
@@ -45,7 +45,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
     # Autoresume.
     _init_autoresume()
 
-    # Random seeds for reproducability.
+    # Random seeds for reproducibility.
     args = get_args()
     if args.rank == 0:
         print('> setting random seeds to {} ...'.format(args.seed))
...
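The edits above only fix comment typos, but the hunk shows the signature of initialize_megatron(extra_args_provider=None, args_defaults={}). A minimal sketch of how an entry-point script might call it, assuming extra_args_provider is a hook that receives the argument parser and that 'seed' is a valid args_defaults key (args.seed appears in the diff); both assumptions are illustrative, not confirmed by this commit.

# Hypothetical usage sketch (not from this commit).
from megatron.initialize import initialize_megatron

def my_extra_args(parser):
    # Hypothetical hook: add script-specific command-line arguments.
    group = parser.add_argument_group(title='my script')
    group.add_argument('--my-flag', action='store_true')
    return parser

if __name__ == '__main__':
    # Per the docstring above: parses args, sets global variables,
    # initializes distributed, autoresume, and random seeds.
    initialize_megatron(extra_args_provider=my_extra_args,
                        args_defaults={'seed': 1234})  # assumed default override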
megatron/training.py

...
@@ -97,7 +97,7 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     print_rank_0('training ...')
 
     iteration = 0
     if args.train_iters > 0:
-        if args.do_train and args.train_iters > 0:
+        if args.do_train:
             iteration, _ = train(forward_step_func,
                                  model, optimizer, lr_scheduler,
...
@@ -151,7 +151,7 @@ def get_model(model_provider_func):
         model = LocalDDP(model)
         return model
 
-    print_rank_0('Unknown DDP implementation specified: {}. '
+    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                  'Exiting.'.format(args.DDP_impl))
     sys.exit()
...
@@ -385,8 +385,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           report_memory_flag)
 
         # Autoresume
-        if (iteration % args.adlr_autoresume_interval == 0) and \
-           args.adlr_autoresume:
+        if args.adlr_autoresume and \
+           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model,
                                              optimizer, lr_scheduler)
...
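The reordering in the last hunk is a short-circuit cleanup: with args.adlr_autoresume checked first, the interval arithmetic is skipped whenever autoresume is disabled, and (assuming the interval could be unset in that case, which this commit does not confirm) the modulo never touches an invalid value. A tiny illustrative sketch with made-up values:

# Hypothetical illustration of the short-circuit (values are made up).
adlr_autoresume = False          # autoresume disabled
adlr_autoresume_interval = None  # assumed to be meaningless when disabled
iteration = 100

# New ordering: the right-hand side is never evaluated when the flag is off.
if adlr_autoresume and (iteration % adlr_autoresume_interval == 0):
    print('checking autoresume termination')

# The old ordering evaluated the modulo first, which would raise a TypeError
# here if the interval really were None.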
pretrain_gpt2.py

...
@@ -109,13 +109,13 @@ def make_gpt2_dataloaders():
     initial_seed = args.seed
 
     # Build the datasets.
-    def build_dataset_(name):
+    def _build_dataset(name):
         return GPT2Dataset(os.path.join(args.data_path, name),
                            args.input_data_sizes_file,
                            args.seq_length, args.seed)
-    train_ds = build_dataset_('train')
-    valid_ds = build_dataset_('valid')
-    test_ds = build_dataset_('test')
+    train_ds = _build_dataset('train')
+    valid_ds = _build_dataset('valid')
+    test_ds = _build_dataset('test')
 
     # Dataloaders
     train = make_data_loader(train_ds)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment