OpenDAS / Megatron-LM · Commits · 85589322

Commit 85589322 authored Jan 25, 2021 by Vijay Korthikanti

dataloader_type argument fix + randomsampler fix

Parent: e6c7b05e
Showing 2 changed files with 10 additions and 5 deletions (+10 -5):

  megatron/arguments.py           +4 -1
  megatron/data/data_samplers.py  +6 -4
megatron/arguments.py

```diff
@@ -124,6 +124,9 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
+    if args.dataloader_type is None:
+        args.dataloader_type = 'single'
+
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
```
```diff
@@ -365,7 +368,7 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
-    group.add_argument('--dataloader_type', type=str, default='single',
+    group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     return parser
```
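Taken together, the two hunks rename the flag from `--dataloader_type` to the conventional dash form `--dataloader-type` (argparse exposes it as `args.dataloader_type` either way) and move its default from `'single'` to `None`, so `parse_args` can tell "left unset" apart from an explicit choice and resolve the default in one place. A minimal standalone sketch of that pattern (an illustrative parser only, not the project's full `_add_training_args`):

```python
import argparse

# Minimal sketch of the pattern in this commit (illustrative, not the
# full Megatron-LM argument setup).
parser = argparse.ArgumentParser()
parser.add_argument('--dataloader-type', type=str, default=None,
                    choices=['single', 'cyclic'],
                    help='Single pass vs multiple pass data loader')

args = parser.parse_args([])          # no flag passed on the command line
# argparse maps the dashed flag to an underscored attribute.
if args.dataloader_type is None:      # default resolved after parsing,
    args.dataloader_type = 'single'   # mirroring the parse_args() hunk above
print(args.dataloader_type)           # -> 'single'
```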
megatron/data/data_samplers.py

```diff
@@ -105,6 +105,8 @@ class MegatronPretrainingRandomSampler:
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
         assert self.total_samples > 0, \
```
```diff
@@ -119,8 +121,9 @@ class MegatronPretrainingRandomSampler:
         return self.total_samples
 
     def __iter__(self):
-        self.epoch = self.consumed_samples // self.total_samples
-        current_epoch_samples = self.consumed_samples % self.total_samples
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
 
         # data sharding and random sampling
```
```diff
@@ -142,4 +145,3 @@ class MegatronPretrainingRandomSampler:
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
-        self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
```
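`MegatronPretrainingRandomSampler` drops the last incomplete global batch of every epoch, so each pass really yields `total_samples - last_batch_size` samples. The old `__iter__` divided `consumed_samples` by `total_samples` and compensated by bumping `consumed_samples` after the loop (the deleted final line), which only holds if the iterator runs to completion; the fix folds the dropped remainder into the epoch arithmetic up front. A standalone numeric sketch of the corrected bookkeeping (made-up sizes, not the full sampler class):

```python
# Standalone sketch of the corrected epoch bookkeeping (made-up sizes,
# not the full MegatronPretrainingRandomSampler).
total_samples = 103                     # dataset size
micro_batch_size = 4
data_parallel_size = 2
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size  # 8

# 103 % 8 = 7 samples form an incomplete final batch and are dropped,
# so each epoch really consumes 96 samples, not 103.
last_batch_size = total_samples % micro_batch_times_data_parallel_size        # 7
active_total_samples = total_samples - last_batch_size                        # 96

# Resuming after two full epochs plus three complete global batches:
consumed_samples = 2 * active_total_samples + 3 * micro_batch_times_data_parallel_size  # 216

epoch = consumed_samples // active_total_samples                 # 2
current_epoch_samples = consumed_samples % active_total_samples  # 24
assert current_epoch_samples % micro_batch_times_data_parallel_size == 0

# Dividing by total_samples (103) instead, as the old code did, would give
# 216 % 103 = 10, which is not a multiple of 8 and trips the assert when
# resuming mid-epoch, unless the deleted end-of-epoch fix-up had already run.
```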