OpenDAS / Megatron-LM · Commit 787c1a0b

Commit 787c1a0b, authored Oct 08, 2019 by Mohammad Shoeybi (parent c882ac61):

    moved few common elements between bert and gpt to utils

Showing 3 changed files with 66 additions and 92 deletions:

- megatron/utils.py   +57 / -2
- pretrain_bert.py    +5 / -45
- pretrain_gpt2.py    +4 / -45
megatron/utils.py  (+57 / -2)

```diff
@@ -22,6 +22,7 @@ import numpy as np

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.fp16 import FP16_Optimizer
 from megatron import mpu
 from megatron import model
@@ -183,13 +184,67 @@ def report_memory(name):
                                 torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)


-def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
+def initialize_distributed(args):
+    """Initialize torch.distributed."""
+
+    # Manually set the device ids.
+    device = args.rank % torch.cuda.device_count()
+    if args.local_rank is not None:
+        device = args.local_rank
+    torch.cuda.set_device(device)
+    # Call the init process
+    init_method = 'tcp://'
+    master_ip = os.getenv('MASTER_ADDR', 'localhost')
+    master_port = os.getenv('MASTER_PORT', '6000')
+    init_method += master_ip + ':' + master_port
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        init_method=init_method)
+
+    # Set the model-parallel / data-parallel communicators.
+    mpu.initialize_model_parallel(args.model_parallel_size)
+
+
+def wrap_model_for_distributed_training(model, args):
+    """Wrap model for distributed training."""
+    if args.DDP_impl == 'torch':
+        i = torch.cuda.current_device()
+        args.DDP_type = torchDDP
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+        return model
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
+        return model
+    else:
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducability."""
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
                         'mp_rank_{:02d}'.format(
-                            mpu.get_model_parallel_rank() if mp_rank is None
+                            mpu.get_model_parallel_rank() if mp_rank is None \
                             else mp_rank),
                         'model_optim_rng.pt')
```
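Taken together, the three relocated helpers form the one-time setup sequence a pretraining script runs before training. Below is a minimal usage sketch, assuming an argparse-style `args` namespace carrying the fields referenced in the diff (`rank`, `local_rank`, `world_size`, `distributed_backend`, `model_parallel_size`, `DDP_impl`, `seed`); `build_model` and `setup_training` are hypothetical stand-ins, not part of this commit.

```python
# Minimal sketch of how the relocated helpers compose in a pretraining
# script. `build_model` is a hypothetical per-script model constructor.
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training


def setup_training(args, build_model):
    # Bind this process to a GPU, create the torch.distributed process group
    # (rendezvous via MASTER_ADDR/MASTER_PORT, defaulting to localhost:6000),
    # and set up the model-parallel / data-parallel communicators in mpu.
    initialize_distributed(args)
    # Seed python, numpy, torch, and the model-parallel CUDA RNG together.
    set_random_seed(args.seed)
    # Build the model, then wrap it for torch or local DDP per args.DDP_impl.
    model = build_model(args)
    return wrap_model_for_distributed_training(model, args)
```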
pretrain_bert.py  (+5 / -45)

```diff
@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import BertModel
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers
@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training


 def get_model(args):
     """Build the model."""
@@ -72,18 +75,7 @@ def get_model(args):
             _module.float()

     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)

     return model
@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return val_loss


-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
```
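The `initialize_distributed` helper that both scripts now import binds each rank to a GPU by rank modulo the per-node device count, unless the launcher supplies an explicit `local_rank`. The snippet below is a standalone restatement of that rule with illustrative numbers; `pick_device` is a hypothetical name, not part of the commit.

```python
# Standalone restatement of the device-selection rule used by
# initialize_distributed(); illustrative only.
def pick_device(rank, num_gpus_per_node, local_rank=None):
    if local_rank is not None:
        return local_rank
    return rank % num_gpus_per_node


# With 4 GPUs per node, global rank 5 lands on device index 1 of its node...
assert pick_device(rank=5, num_gpus_per_node=4) == 1
# ...unless a launcher (e.g. torch.distributed.launch) passes local_rank.
assert pick_device(rank=5, num_gpus_per_node=4, local_rank=3) == 3
```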
pretrain_gpt2.py  (+4 / -45)

```diff
@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer
 from megatron.learning_rates import AnnealingLR
 from megatron.model import GPT2Model
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers
@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training
 from gpt2_data_loader import make_gpt2_dataloaders
@@ -73,18 +75,7 @@ def get_model(args):
         model = FP16_Module(model)

     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)

     return model
@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return lm_loss


-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
```
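One subtlety the shared helper preserves: in the `'torch'` branch, DDP receives `process_group=mpu.get_data_parallel_group()`, so gradient all-reduce spans only data-parallel replicas, never the ranks holding different slices of the same model. A sketch of the topology arithmetic, assuming an 8-GPU example and the stride pattern I believe `mpu.initialize_model_parallel` used in this era of the code (both are assumptions, not taken from the commit):

```python
# Illustrative topology arithmetic: with 8 ranks and 2-way model parallelism,
# mpu partitions the world into 4 data-parallel replicas, so DDP must
# all-reduce within groups of 4 -- which is what passing
# mpu.get_data_parallel_group() to DistributedDataParallel accomplishes.
world_size = 8
model_parallel_size = 2
data_parallel_size = world_size // model_parallel_size

# Assumed grouping: ranks at the same position within each model-parallel
# slice form one data-parallel group.
data_parallel_groups = [list(range(i, world_size, model_parallel_size))
                        for i in range(model_parallel_size)]
print(data_parallel_size)    # 4
print(data_parallel_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```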