OpenDAS / Megatron-LM

Commit 787c1a0b authored Oct 08, 2019 by Mohammad Shoeybi
Parent: c882ac61

    moved few common elements between bert and gpt to utils
Showing 3 changed files with 66 additions and 92 deletions (+66, -92):

    megatron/utils.py     +57   -2
    pretrain_bert.py       +5  -45
    pretrain_gpt2.py       +4  -45
megatron/utils.py @ 787c1a0b
@@ -22,6 +22,7 @@ import numpy as np
 
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.fp16 import FP16_Optimizer
 from megatron import mpu
 from megatron import model

@@ -183,13 +184,67 @@ def report_memory(name):
         torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)
 
 
+def initialize_distributed(args):
+    """Initialize torch.distributed."""
+
+    # Manually set the device ids.
+    device = args.rank % torch.cuda.device_count()
+    if args.local_rank is not None:
+        device = args.local_rank
+    torch.cuda.set_device(device)
+    # Call the init process
+    init_method = 'tcp://'
+    master_ip = os.getenv('MASTER_ADDR', 'localhost')
+    master_port = os.getenv('MASTER_PORT', '6000')
+    init_method += master_ip + ':' + master_port
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        init_method=init_method)
+
+    # Set the model-parallel / data-parallel communicators.
+    mpu.initialize_model_parallel(args.model_parallel_size)
+
+
+def wrap_model_for_distributed_training(model, args):
+    """Wrap model for distributed training."""
+    if args.DDP_impl == 'torch':
+        i = torch.cuda.current_device()
+        args.DDP_type = torchDDP
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+        return model
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
+        return model
+    else:
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducability."""
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+
 def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None \
-                                                else mp_rank),
+                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None
+                                                else mp_rank),
                         'model_optim_rng.pt')
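Since get_checkpoint_name composes its result with nothing more than str.format and os.path.join, the checkpoint layout it produces is easy to illustrate. A minimal sketch, assuming the megatron package and its dependencies (torch, apex) are importable; mp_rank is passed explicitly so the call into mpu.get_model_parallel_rank(), which needs the model-parallel groups to exist, is skipped, and the 'checkpoints' directory is just an example name:

# Illustrative only: mp_rank=0 sidesteps mpu.get_model_parallel_rank(), which
# would otherwise require initialize_distributed() to have been called first.
from megatron.utils import get_checkpoint_name

print(get_checkpoint_name('checkpoints', 1000, mp_rank=0))
# checkpoints/iter_0001000/mp_rank_00/model_optim_rng.pt

print(get_checkpoint_name('checkpoints', 1000, release=True, mp_rank=0))
# checkpoints/release/mp_rank_00/model_optim_rng.pt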
pretrain_bert.py @ 787c1a0b
@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import BertModel
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers

@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training
 
 
 def get_model(args):
     """Build the model."""

@@ -72,18 +75,7 @@ def get_model(args):
             _module.float()
 
     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)
 
     return model

@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return val_loss
 
 
-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
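With the common pieces moved into megatron/utils.py, the distributed setup that pretrain_bert.py performs reduces to three helper calls. A minimal sketch of that call pattern rather than the script itself: the argparse.Namespace below carries only the attributes these helpers actually read (a real run takes them from the script's argument parser), a toy torch.nn.Linear stands in for BertModel, and a single CUDA device is assumed:

import argparse

import torch

from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training

# Only the attributes the helpers read; everything else provided by the real
# argument parser is omitted from this sketch.
args = argparse.Namespace(
    rank=0, local_rank=0, world_size=1,
    distributed_backend='nccl',
    model_parallel_size=1,
    seed=1234,            # > 0, so set_random_seed actually seeds
    DDP_impl='torch',     # or 'local' for megatron's own DistributedDataParallel
)

initialize_distributed(args)            # torch.distributed + model/data-parallel groups
set_random_seed(args.seed)              # python, numpy, torch and model-parallel RNGs

model = torch.nn.Linear(16, 16).cuda()  # stand-in for BertModel
model = wrap_model_for_distributed_training(model, args)

The same pattern applies to pretrain_gpt2.py below, with GPT2Model in place of the stand-in.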
pretrain_gpt2.py @ 787c1a0b
@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer
 from megatron.learning_rates import AnnealingLR
 from megatron.model import GPT2Model
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers

@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training
 
 from gpt2_data_loader import make_gpt2_dataloaders

@@ -73,18 +75,7 @@ def get_model(args):
         model = FP16_Module(model)
 
     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)
 
     return model

@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return lm_loss
 
 
-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
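Both pretraining scripts now delegate process-group setup to initialize_distributed, which reads MASTER_ADDR and MASTER_PORT from the environment and derives the GPU from the rank. A hypothetical single-node, two-process launch that exercises that path; it assumes two GPUs, that megatron is importable, and uses torch.multiprocessing.spawn (a standard PyTorch utility) purely for illustration:

import argparse
import os

import torch.multiprocessing as mp

from megatron.utils import initialize_distributed


def worker(rank, world_size):
    # These are the variables initialize_distributed() reads; the values shown
    # match its own defaults and are set here only to make the wiring explicit.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6000'
    args = argparse.Namespace(
        rank=rank,                  # global rank; also selects the GPU below
        local_rank=None,            # fall back to rank % torch.cuda.device_count()
        world_size=world_size,
        distributed_backend='nccl',
        model_parallel_size=1,      # pure data parallelism across the two ranks
    )
    initialize_distributed(args)    # init_method becomes 'tcp://localhost:6000'


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)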