OpenDAS / Megatron-LM · Commits

Commit 5942af97, authored Sep 23, 2022 by Jared Casper

    Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
Parent: c2ea914f

Changes: 63 files in total; showing 3 changed files on this page, with 17 additions and 15 deletions (+17 -15).
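Note: the hunks on this page only update call sites; the alias itself presumably lands in megatron/core/__init__.py, which is not among the files shown here. A minimal sketch of what such an alias could look like (an assumption, not the commit's actual file contents):

    # megatron/core/__init__.py -- hypothetical sketch of the alias, not the
    # file contents from this commit (that file is not shown on this page).
    from . import parallel_state

    # Expose parallel_state under its legacy name so call sites can keep
    # writing `from megatron.core import mpu` now that the old megatron.mpu
    # package is retired.
    mpu = parallel_state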
tools/checkpoint_loader_megatron.py   +7 -6
tools/checkpoint_saver_megatron.py    +9 -8
tools/run_text_generation_server.py   +1 -1
tools/checkpoint_loader_megatron.py @ 5942af97
@@ -30,7 +30,8 @@ def _load_checkpoint(queue, args):
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType, module
-        from megatron import mpu, fused_kernels
+        from megatron.core import mpu
+        from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         queue.put("exit")
@@ -99,7 +100,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.initialize.set_tensor_model_parallel_rank(rank)
+            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -123,8 +124,8 @@ def _load_checkpoint(queue, args):
         exit(1)

     set_global_variables(margs)
-    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)

     # Get true (non-padded) vocab size
@@ -162,7 +163,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

     # Get first pipe stage
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
@@ -188,7 +189,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
tools/checkpoint_saver_megatron.py @ 5942af97
@@ -34,7 +34,8 @@ def save_checkpoint(queue, args):
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-        from megatron import mpu, fused_kernels
+        from megatron import fused_kernels
+        from megatron.core import mpu
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
@@ -152,10 +153,10 @@ def save_checkpoint(queue, args):
             return models

     # fake initializing distributed
-    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
-    mpu.initialize.set_tensor_model_parallel_rank(0)
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
+    mpu.set_tensor_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     fused_kernels.load(margs)

     # Embeddings
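Aside: the "fake initializing distributed" block above works because these setters presumably just record the topology in module-level state, letting a single conversion process impersonate any rank without torch.distributed ever being initialized. A minimal sketch of the pattern, with illustrative sizes that are not taken from the commit:

    # Emulate a 2-way tensor-parallel, single-stage pipeline topology in one
    # process. The setters only write bookkeeping state; no process group is
    # created. The sizes below are illustrative.
    from megatron.core import mpu

    mpu.set_tensor_model_parallel_world_size(2)
    mpu.set_pipeline_model_parallel_world_size(1)
    mpu.set_tensor_model_parallel_rank(0)
    mpu.set_pipeline_model_parallel_rank(0)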
@@ -197,7 +198,7 @@ def save_checkpoint(queue, args):
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

     # Make models for first pipeline stage and fill in embeddings
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
     for tp_rank, model in enumerate(models):
@@ -211,7 +212,7 @@ def save_checkpoint(queue, args):
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == args.target_pipeline_parallel_size - 1
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
@@ -317,6 +318,6 @@ def save_checkpoint(queue, args):
         print("ERROR: got some more data but was expecting to be done")

     for tp_rank in range(args.target_tensor_parallel_size):
-        mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
+        mpu.set_tensor_model_parallel_rank(tp_rank)
         save_checkpoint(md.iteration, [models[tp_rank]], None, None)

     print("Done!")
tools/run_text_generation_server.py @ 5942af97
@@ -8,7 +8,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 import socket
 from megatron import get_args
 from megatron import print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel