OpenDAS / Megatron-LM · Commit 52a5f2f2
Authored Oct 20, 2020 by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

Parent: 7abd3e90
Changes: 43 · Showing 3 changed files with 16 additions and 16 deletions (+16 / -16)
pretrain_ict.py               +2  -2
tools/merge_mp_partitions.py  +13 -13
tools/preprocess_data.py      +1  -1
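This commit renames the model-parallel terminology throughout: what the code previously called intra-layer model parallelism becomes tensor model parallelism, and inter-layer model parallelism becomes pipeline model parallelism. In the diffs below this shows up as straight identifier renames, e.g. args.intra_layer_model_parallel_size -> args.tensor_model_parallel_size and args.inter_layer_model_parallel_size -> args.pipeline_model_parallel_size, with the mpu.initialize setters renamed to match.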
pretrain_ict.py

@@ -32,7 +32,7 @@ from megatron.data.realm_dataset_utils import get_ict_batch

 def pretrain_ict_model_provider():
     args = get_args()
-    assert args.inter_layer_model_parallel_size == 1, 'inter_layer_model_parallel_size must be 1!'
+    assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
     return general_ict_model_provider(False, False)

@@ -89,7 +89,7 @@ def forward_step(data_iterator, model, input_tensor):

     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask,
                                        block_tokens, block_pad_mask)
     local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that intra_layer_model_parallel_size == 1
+    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that tensor_model_parallel_size == 1

     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
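For context on the renamed comment above: the code relies on every rank being a pure data-parallel replica, so the global batch is the local batch times the world size. Below is a minimal sketch of such an all-gather using plain torch.distributed; it is an illustration only, not Megatron's AllgatherFromDataParallelRegion (which also defines a backward pass).

import torch
import torch.distributed as dist

def allgather_logits(local_logits):
    # With tensor_model_parallel_size == 1, every rank is a data-parallel
    # replica, so global_batch_size == world_size * local_batch_size.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(gathered, local_logits)
    # Concatenate along the batch dimension. Note that dist.all_gather does
    # not propagate gradients, unlike Megatron's autograd-aware region op.
    return torch.cat(gathered, dim=0)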
tools/merge_mp_partitions.py

@@ -188,18 +188,18 @@ def main():

     # Args
     args = _parse_args(extra_args_provider=get_mp_merge_args)
     model_type = args.model_type
-    orig_intra_layer_model_parallel_size = args.intra_layer_model_parallel_size
-    args.intra_layer_model_parallel_size = 1
+    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
+    args.tensor_model_parallel_size = 1
     tokenizer = rebuild_tokenizer(args)

     print('\n merging model parallel partitions ...')
-    print(' > number of partitions: {}'.format(orig_intra_layer_model_parallel_size))
+    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
     print(' > model parameters:')
     print('    number of tokens ................ {} '.format(tokenizer.vocab_size))
     print('    number of layers ................ {}'.format(args.num_layers))
-    print('    hidden sise ..................... {}'.format(args.hidden_size))
+    print('    hidden size ..................... {}'.format(args.hidden_size))
     print('    number of attention heads ....... {}'.format(args.num_attention_heads))
     print('    maximum position embeddings ..... {}'.format(

@@ -207,18 +207,18 @@ def main():

     # Full model.
     print('> building the full model ...')
-    mpu.initialize.set_intra_layer_model_parallel_world_size(1)
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    mpu.initialize.set_tensor_model_parallel_world_size(1)
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     merged_model = get_model(model_type)

     # Build and load partitions.
     partitions = []
     iteration = 0
-    args.intra_layer_model_parallel_size = orig_intra_layer_model_parallel_size
+    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
     tokenizer = rebuild_tokenizer(args)
-    mpu.initialize.set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
-    for rank in range(args.intra_layer_model_parallel_size):
-        mpu.initialize.set_intra_layer_model_parallel_rank(rank)
+    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+    for rank in range(args.tensor_model_parallel_size):
+        mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         print('> loading {} ...'.format(checkpoint_name))
         model_ = get_model(model_type)

@@ -248,7 +248,7 @@ def main():

                 rank, partition_param.dtype, list(partition_param.size())))
             # For the non-parallel parameters, simply copy the rank 0 values.
-            if not hasattr(merged_param, 'intra_layer_model_parallel'):
+            if not hasattr(merged_param, 'tensor_model_parallel'):
                 print('     none-parallel parameter, simple copy from rank 0')
                 with torch.no_grad():
                     merged_param.data.copy_(partitions_param[0].data)

@@ -267,8 +267,8 @@ def main():

     # Save the model.
-    args.intra_layer_model_parallel_size = 1
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    args.tensor_model_parallel_size = 1
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     sd = {}
     sd['model'] = merged_model.state_dict_for_save_checkpoint()
     sd['iteration'] = iteration
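The hunks above only rename the attribute that marks a parameter as partitioned ('intra_layer_model_parallel' -> 'tensor_model_parallel'). As a rough sketch of what merging does for partitioned parameters, assuming shards are concatenated along their split dimension (the helper name and shapes below are illustrative, not the tool's actual code):

import torch

def merge_parameter(partitions, partition_dim=None):
    # Non-parallel parameter: every rank holds an identical copy, so take
    # rank 0 (mirroring the "simple copy from rank 0" branch in the diff).
    if partition_dim is None:
        return partitions[0].clone()
    # Tensor-parallel parameter: stitch the per-rank shards back together.
    return torch.cat(partitions, dim=partition_dim)

# Example: a column-parallel weight split across two ranks along dim 0.
shards = [torch.randn(512, 1024), torch.randn(512, 1024)]
merged = merge_parameter(shards, partition_dim=0)
assert merged.shape == (1024, 1024)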
tools/preprocess_data.py

@@ -136,7 +136,7 @@ def get_args():

     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
-    args.intra_layer_model_parallel_size = 1
+    args.tensor_model_parallel_size = 1
     return args
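The dummy values exist because building the tokenizer pads the vocabulary so it can be split evenly across tensor-parallel ranks; with no real training job, rank 0 and a parallel size of 1 are enough. A rough sketch of that padding rule (the function name is illustrative, not necessarily the library's exact helper):

def padded_vocab_size(orig_vocab_size, make_vocab_size_divisible_by=128,
                      tensor_model_parallel_size=1):
    # Pad to the next multiple of (divisor * tensor-parallel size) so each
    # rank gets an equal slice of the embedding table.
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    padded = orig_vocab_size
    while padded % multiple != 0:
        padded += 1
    return padded

print(padded_vocab_size(50257))  # 50304 with the defaults set above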