Megatron-LM commit a7ee77ea
Authored Jan 19, 2022 by Vijay Korthikanti

flag for data parallel random initialization

Parent: fd8dd9c0
Showing 4 changed files with 21 additions and 8 deletions.
megatron/arguments.py      +3  -0
megatron/checkpointing.py  +7  -2
megatron/initialize.py     +7  -4
megatron/training.py       +4  -2
megatron/arguments.py

@@ -518,6 +518,9 @@ def _add_initialization_args(parser):
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
+    group.add_argument('--data-parallel-random-init', action='store_true',
+                       help='Enable random initialization of params '
+                       'across data parallel ranks')
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
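For quick reference, a minimal standalone argparse sketch (illustrative only, not taken from Megatron's parser) of how a store_true flag like the one added above behaves: it defaults to False and flips to True only when passed on the command line.

import argparse

# Minimal illustration of the store_true pattern used in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument('--data-parallel-random-init', action='store_true',
                    help='Enable random initialization of params '
                    'across data parallel ranks')

print(parser.parse_args([]).data_parallel_random_init)   # False (default)
print(parser.parse_args(
    ['--data-parallel-random-init']).data_parallel_random_init)  # True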
megatron/checkpointing.py

@@ -142,6 +142,7 @@ def read_metadata(tracker_filename):
 def get_rng_state():
     """ collect rng state across data parallel ranks """
+    args = get_args()
     rng_state = {
         'random_rng_state': random.getstate(),
         'np_rng_state': np.random.get_state(),

@@ -151,7 +152,8 @@ def get_rng_state():
     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            mpu.get_data_parallel_world_size() > 1:
+            mpu.get_data_parallel_world_size() > 1 and \
+            args.data_parallel_random_init:
         if mpu.get_data_parallel_rank() == 0:
             rng_state_list = \
                 [None for i in range(mpu.get_data_parallel_world_size())]

@@ -407,7 +409,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if 'rng_state' in state_dict:
                 # access rng_state for data parallel rank
-                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                if args.data_parallel_random_init:
+                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                else:
+                    rng_state = state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
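The get_rng_state hunks above stop before the collective that actually fills rng_state_list. As a hedged illustration of the underlying pattern (not necessarily the exact call Megatron makes), one way to gather each data-parallel rank's RNG state dict is torch.distributed.all_gather_object over the data-parallel group; the group handle, the subset of RNG states, and an already-initialized process group are assumptions of this sketch.

import random
import numpy as np
import torch
import torch.distributed as dist

def gather_rng_states(data_parallel_group):
    # Build the same kind of per-rank RNG snapshot as get_rng_state above
    # (only a subset of the states is shown here).
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
    }
    world_size = dist.get_world_size(group=data_parallel_group)
    rng_state_list = [None for _ in range(world_size)]
    # all_gather_object pickles arbitrary Python objects and gives every rank
    # the full list, indexed by rank within the group.
    dist.all_gather_object(rng_state_list, rng_state, group=data_parallel_group)
    return rng_state_list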
megatron/initialize.py

@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Random seeds for reproducibility.
         if args.rank == 0:
             print('> setting random seeds to {} ...'.format(args.seed))
-        _set_random_seed(args.seed)
+        _set_random_seed(args.seed, args.data_parallel_random_init)

     # Set pytorch JIT layer fusion options.
     _set_jit_fusion_options()

@@ -203,11 +203,14 @@ def _init_autoresume():
         torch.distributed.barrier()


-def _set_random_seed(seed_):
+def _set_random_seed(seed_, data_parallel_random_init=False):
     """Set random seed for reproducability."""
     if seed_ is not None and seed_ > 0:
-        # Ensure that different pipeline MP stages and different data parallel ranks get different seeds.
-        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + (10 * mpu.get_data_parallel_rank())
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
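A quick sanity check of the seed-offset arithmetic in _set_random_seed, written with plain integers (illustrative only; pp_rank and dp_rank stand in for the mpu rank queries):

def derived_seed(base_seed, pp_rank, dp_rank, data_parallel_random_init=False):
    # Mirrors the offsets above: +100 per pipeline stage, and, only when the
    # flag is set, +10 per data-parallel rank.
    seed = base_seed + 100 * pp_rank
    if data_parallel_random_init:
        seed = seed + 10 * dp_rank
    return seed

assert derived_seed(1234, pp_rank=1, dp_rank=3) == 1334
assert derived_seed(1234, pp_rank=1, dp_rank=3,
                    data_parallel_random_init=True) == 1364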
megatron/training.py

@@ -285,6 +285,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                               args.accumulate_allreduce_grads_in_fp32,
                               args.use_contiguous_buffers_in_local_ddp)
                      for model_module in model]
-            for model_module in model:
-                model_module.broadcast_params()
+            # broad cast params from data parallel src rank to other data parallel ranks
+            if args.data_parallel_random_init:
+                for model_module in model:
+                    model_module.broadcast_params()
         else:
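broadcast_params is a method on Megatron's local DDP wrapper and its body is not part of this diff. As a hedged sketch of what "broadcast params from data parallel src rank" typically means, the hypothetical helper below pushes one rank's weights to every other rank in the data-parallel group; the group handle and source-rank argument are assumptions of the sketch, not the wrapper's actual signature.

import torch.distributed as dist

def broadcast_params(module, data_parallel_group, src_global_rank):
    # Overwrite every replica's parameters in place with the copy held by
    # src_global_rank, so all data-parallel ranks start from identical weights
    # even if each rank seeded its RNG differently at initialization.
    for param in module.parameters():
        dist.broadcast(param.data, src=src_global_rank,
                       group=data_parallel_group)

In the hunk above, the broadcast runs only when --data-parallel-random-init is set, which is the case where data-parallel ranks seed their RNGs differently and the replicas need to be re-synchronized from the source rank before training; without the flag, all data-parallel ranks use the same seed and already hold identical parameters.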