OpenDAS / Megatron-LM

Commit 57c2060f

Model parallel merger

Authored Feb 10, 2020 by Mohammad Shoeybi; committed by Jared Casper, Feb 10, 2020.
Parent: 5df85022
Showing 6 changed files with 276 additions and 5 deletions:

arguments.py (+1, -1)
megatron/model/bert_model.py (+2, -0)
megatron/model/transformer.py (+1, -0)
megatron/mpu/initialize.py (+22, -0)
megatron/mpu/layers.py (+7, -4)
merge_mp_partitions.py (new file, +243, -0)
arguments.py

@@ -47,7 +47,7 @@ def add_model_config_args(parser):
                        help='dropout probability for hidden state transformer')
     group.add_argument('--max-position-embeddings', type=int, default=512,
                        help='maximum number of position embeddings to use')
-    group.add_argument('--vocab-size', type=int, default=30522,
+    group.add_argument('--vocab-size', type=int, default=None,
                        help='vocab size to use for non-character-level '
                        'tokenization. This value will only be used when '
                        'creating a tokenizer')
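The only change here is the --vocab-size default (from 30522 to None), so callers must now supply the vocab size explicitly; the new merge script added below asserts exactly that. A standalone argparse sketch of the new behavior (illustration only, not the repo's parser):

import argparse

# Hypothetical stand-alone parser mirroring the changed default.
parser = argparse.ArgumentParser()
parser.add_argument('--vocab-size', type=int, default=None,
                    help='vocab size to use for non-character-level tokenization')

args = parser.parse_args([])
assert args.vocab_size is None                       # must now be passed explicitly
args = parser.parse_args(['--vocab-size', '30522'])
assert args.vocab_size == 30522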
megatron/model/bert_model.py

@@ -83,6 +83,8 @@ class BertLMHead(MegatronModule):
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
+        self.bias.partition_dim = 0
+        self.bias.stride = 1
         self.parallel_output = parallel_output
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
megatron/model/transformer.py

@@ -372,6 +372,7 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelTransformerLayer, self).__init__()
+        self.layer_number = layer_number

         self.apply_residual_connection_post_layernorm \
             = hyperparameters['apply_residual_connection_post_layernorm']
megatron/mpu/initialize.py

@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+# These values enable us to change the mpu sizes on the fly.
+_MPU_WORLD_SIZE = None
+_MPU_RANK = None
+

 def initialize_model_parallel(model_parallel_size_):
     """

@@ -99,13 +103,31 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP


+def set_model_parallel_world_size(world_size):
+    """Set the model parallel size"""
+    global _MPU_WORLD_SIZE
+    _MPU_WORLD_SIZE = world_size
+
+
 def get_model_parallel_world_size():
     """Return world size for the model parallel group."""
+    global _MPU_WORLD_SIZE
+    if _MPU_WORLD_SIZE is not None:
+        return _MPU_WORLD_SIZE
     return torch.distributed.get_world_size(group=get_model_parallel_group())


+def set_model_parallel_rank(rank):
+    """Set model parallel rank."""
+    global _MPU_RANK
+    _MPU_RANK = rank
+
+
 def get_model_parallel_rank():
     """Return my rank for the model parallel group."""
+    global _MPU_RANK
+    if _MPU_RANK is not None:
+        return _MPU_RANK
     return torch.distributed.get_rank(group=get_model_parallel_group())
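These setters exist so that a single, non-distributed process can pretend to be any model-parallel rank, which is exactly how the new merge_mp_partitions.py builds the full model and each partition in turn. A minimal sketch of that pattern (build_model is a hypothetical stand-in for the script's get_model):

from megatron import mpu

def build_merged_and_partitions(model_parallel_size, build_model):
    # Sketch only: build_model() stands in for merge_mp_partitions.get_model.
    # Build the full (merged) model as if there were a single model-parallel rank.
    mpu.initialize.set_model_parallel_world_size(1)
    mpu.initialize.set_model_parallel_rank(0)
    merged_model = build_model()

    # Re-build the model once per original rank so the parallel layers
    # allocate their per-partition shapes.
    mpu.initialize.set_model_parallel_world_size(model_parallel_size)
    partitions = []
    for rank in range(model_parallel_size):
        mpu.initialize.set_model_parallel_rank(rank)
        partitions.append(build_model())
    return merged_model, partitions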
megatron/mpu/layers.py

@@ -46,6 +46,11 @@ def _initialize_affine_weight(weight, output_size, input_size,
     Build the master weight on all processes and scatter
     the relevant chunk."""

+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.stride = stride
+
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:

@@ -108,7 +113,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Allocate weights.
         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
                                              self.embedding_dim))
-        self.weight.model_parallel = True
         # And initialize.
         _initialize_affine_weight(self.weight, self.num_embeddings, self.embedding_dim,

@@ -165,7 +169,6 @@ class ParallelEmbedding(torch.nn.Module):
         # Allocate weights.
         self.weight = Parameter(torch.Tensor(self.num_embeddings,
                                              self.embedding_dim_per_partition))
-        self.weight.model_parallel = True
         # And initialize.
         _initialize_affine_weight(self.weight, self.num_embeddings, self.embedding_dim,

@@ -220,10 +223,11 @@ class ColumnParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
                                              self.input_size))
-        self.weight.model_parallel = True
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
             self.bias.model_parallel = True
+            self.bias.partition_dim = 0
+            self.bias.stride = stride
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()

@@ -294,7 +298,6 @@ class RowParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         self.weight = Parameter(torch.Tensor(self.output_size,
                                              self.input_size_per_partition))
-        self.weight.model_parallel = True
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size))
             # Always initialize bias to zero.
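With this change every model-parallel parameter carries model_parallel, partition_dim, and stride attributes, while replicated parameters carry none; the merge script keys off exactly that distinction. A small self-contained sketch of the tagging and of the consumer-side check (tag_model_parallel is a hypothetical helper mirroring _initialize_affine_weight, not the repo's function):

import torch

def tag_model_parallel(param, partition_dim, stride=1):
    # Hypothetical helper: mirrors the attribute tagging now centralized
    # in _initialize_affine_weight.
    param.model_parallel = True
    param.partition_dim = partition_dim
    param.stride = stride
    return param

weight = tag_model_parallel(torch.nn.Parameter(torch.empty(4, 8)), partition_dim=0)
ln_bias = torch.nn.Parameter(torch.zeros(8))   # replicated parameter, left untagged

for param in (weight, ln_bias):
    if hasattr(param, 'model_parallel'):
        print('parallel parameter: dim={}, stride={}'.format(param.partition_dim,
                                                             param.stride))
    else:
        print('replicated parameter: copy from rank 0')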
merge_mp_partitions.py (new file, mode 100644)

import os

import torch

from arguments import get_args
from megatron import mpu
from megatron.utils import ensure_directory_exists
from megatron.utils import get_checkpoint_name
from megatron.utils import get_checkpoint_tracker_filename
from megatron.utils import vocab_size_with_padding


def split_into_partitions(tensor, num_partitions, partition_dim, stride):

    per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
                                          num_partitions)
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)

    partitions_list = torch.split(tensor,
                                  per_partition_per_stride_size,
                                  dim=partition_dim)

    partitions = []
    for i in range(num_partitions):
        partition = torch.cat(partitions_list[i::num_partitions],
                              dim=partition_dim)
        partitions.append(partition)

    return partitions
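With stride > 1 the split treats the partition dimension as stride interleaved groups (e.g. a fused QKV projection) and gives each partition the matching slice of every group, which is what the indexing partitions_list[i::num_partitions] achieves. A hand-worked one-dimensional sketch (illustration only, not part of the file):

import torch

# Illustration only: 6 rows = 3 strided groups (q, k, v) of 2 rows each,
# stored as [q0, q1, k0, k1, v0, v1].
t = torch.arange(6.)
num_partitions, stride = 2, 3
per_partition_per_stride_size = t.numel() // num_partitions // stride   # 1
chunks = torch.split(t, per_partition_per_stride_size)   # (q0,), (q1,), (k0,), ...
p0 = torch.cat(chunks[0::num_partitions])   # tensor([0., 2., 4.])  -> q0, k0, v0
p1 = torch.cat(chunks[1::num_partitions])   # tensor([1., 3., 5.])  -> q1, k1, v1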
def merge_partitions(merged, partitions, partition_dim, stride):

    # Number and size of each partition.
    num_partitions = len(partitions)
    per_partition_size = None
    for partition in partitions:
        if per_partition_size is None:
            per_partition_size = partition.size(partition_dim)
        else:
            assert per_partition_size == partition.size(partition_dim)

    def concat_partitions(partitions_):
        with torch.no_grad():
            if (per_partition_size * num_partitions) == merged.size(partition_dim):
                torch.cat(partitions_, dim=partition_dim, out=merged)
            else:
                print('    ***WARNING*** sizes do not match. Will cut '
                      'the merged partitions by {} along dimension {} '
                      'to reduce the size from {} to {} ...'.format(
                          (per_partition_size * num_partitions) -
                          merged.size(partition_dim), partition_dim,
                          per_partition_size * num_partitions,
                          merged.size(partition_dim)))
                merged_ = torch.cat(partitions_, dim=partition_dim)
                merged_split = torch.split(merged_, merged.size(partition_dim),
                                           dim=partition_dim)
                merged_ = merged_split[0]
                assert merged_.size(partition_dim) == merged.size(partition_dim)
                merged.data.copy_(merged_.data)

    # If stride is 1, then do a simple concatenation.
    if stride == 1:
        concat_partitions(partitions)
        return

    # For non-unity strides, first split based on stride and then group.
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
    # Chunk and build a list.
    chunks = None
    for i, partition in enumerate(partitions):
        chunk = torch.split(partition,
                            per_partition_per_stride_size,
                            dim=partition_dim)

        if chunks is None:
            chunks = [0] * (num_partitions * len(chunk))
        chunks[i::num_partitions] = chunk

    # Concatenate.
    concat_partitions(chunks)

    return
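The warning branch handles the case where the partitions were built against a padded size (for example a padded vocabulary) and together are larger than the merged target; the concatenation is simply cut back to the target length along the partition dimension. A tiny sketch of that cut (illustration only, not part of the file):

import torch

# Illustration only: two padded 3-row partitions merged into a 5-row target.
merged = torch.zeros(5, 2)
parts = [torch.ones(3, 2), 2 * torch.ones(3, 2)]
cat = torch.cat(parts, dim=0)                                    # 6 rows
merged.data.copy_(torch.split(cat, merged.size(0), dim=0)[0])    # keep the first 5 rows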
def get_model(model_type, args):

    if model_type == 'BERT':
        from pretrain_albert import model_provider
        args.tokentype_size = 2
    elif model_type == 'GPT':
        from pretrain_gpt2 import model_provider
    else:
        raise Exception('unrecognized model type: {}'.format(model_type))

    orig_vocab_size = args.vocab_size
    args.vocab_size = vocab_size_with_padding(args.vocab_size, args)
    model = model_provider(args)
    model = model.half()
    args.vocab_size = orig_vocab_size

    return model


def get_parallel_checkpoint_name(path):

    tracker_filename = get_checkpoint_tracker_filename(path)
    iteration = 0
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        iteration = int(metastring)
    assert iteration > 0
    checkpoint_name = get_checkpoint_name(path, iteration)

    return checkpoint_name, iteration


def test_split_merge():

    print('testing split and merge ...')

    # [QKV.ROW-COL]
    tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
                                [1.21, 1.22, 1.23, 1.24, 1.25],
                                [1.31, 1.32, 1.33, 1.34, 1.35],
                                [1.41, 1.42, 1.43, 1.44, 1.45],
                                [2.11, 2.12, 2.13, 2.14, 2.15],
                                [2.21, 2.22, 2.23, 2.24, 2.25],
                                [2.31, 2.32, 2.33, 2.34, 2.35],
                                [2.41, 2.42, 2.43, 2.44, 2.45],
                                [3.11, 3.12, 3.13, 3.14, 3.15],
                                [3.21, 3.22, 3.23, 3.24, 3.25],
                                [3.31, 3.32, 3.33, 3.34, 3.35],
                                [3.41, 3.42, 3.43, 3.44, 3.45]])

    num_partitions = 2
    partition_dim = 0
    stride = 3
    partitions = split_into_partitions(tensor, num_partitions,
                                       partition_dim, stride)

    merged = torch.zeros_like(tensor)
    merge_partitions(merged, partitions, partition_dim, stride)

    max_error = (merged - tensor).abs().max()
    print('  > max error (should be zero): {}'.format(max_error))


def main(model_type):

    # Args
    args = get_args()

    print('\nmerging model parallel partitions ...')
    assert args.vocab_size is not None
    print(' > number of partitions: {}'.format(args.model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {} '.format(args.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden size ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print('    maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_model_parallel_world_size(1)
    mpu.initialize.set_model_parallel_rank(0)
    merged_model = get_model(model_type, args)

    # Build and load partitions.
    partitions = []
    iteration = 0
    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
    for rank in range(args.model_parallel_size):
        mpu.initialize.set_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        print('> loading {} ...'.format(checkpoint_name))
        model_ = get_model(model_type, args)
        sd = torch.load(checkpoint_name, map_location='cpu')
        model_.load_state_dict(sd['model'])
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    merged_params_gen = merged_model.named_parameters()
    partitions_params_gen = [partition.named_parameters()
                             for partition in partitions]
    while True:
        try:
            # Get the params and check names.
            name, merged_param = next(merged_params_gen)
            print(' > working on {} ...'.format(name))
            print('     merged         type: {}, size: {}'.format(
                merged_param.dtype, list(merged_param.size())))
            partitions_param = []
            for rank, partition_params_gen in enumerate(partitions_params_gen):
                partition_name, partition_param = next(partition_params_gen)
                assert partition_name == name
                partitions_param.append(partition_param)
                print('     partition {}    type: {}, size: {}'.format(
                    rank, partition_param.dtype, list(partition_param.size())))
            # For the non-parallel parameters, simply copy the rank 0 values.
            if not hasattr(merged_param, 'model_parallel'):
                print('     non-parallel parameter, simple copy from rank 0')
                with torch.no_grad():
                    merged_param.data.copy_(partitions_param[0].data)
            # For parallel parameters, merge the values.
            else:
                print('     parallel parameter merge with stride {} along '
                      'dimension {}'.format(merged_param.stride,
                                            merged_param.partition_dim))
                merge_partitions(merged_param, partitions_param,
                                 merged_param.partition_dim,
                                 merged_param.stride)
        except StopIteration:
            break

    # Save the model.
    mpu.initialize.set_model_parallel_rank(0)
    sd = {}
    sd['model'] = merged_model.state_dict_for_save_checkpoint()
    sd['iteration'] = iteration
    merged_path = os.path.join(args.load, 'merged')
    checkpoint_name = get_checkpoint_name(merged_path, iteration)
    ensure_directory_exists(checkpoint_name)
    print('> saving merged model to {}'.format(checkpoint_name))
    torch.save(sd, checkpoint_name)

    print('done :-)')


if __name__ == '__main__':

    main('BERT')