OpenDAS / Megatron-LM, commit 5c45db4a

Authored Dec 09, 2020 by Jared Casper; committed Dec 19, 2020 by Deepak Narayanan.
Parent: caa9dca5

Initial implementation of pipelined text generation

Showing 4 changed files with 280 additions and 115 deletions (+280, -115).
- megatron/mpu/__init__.py (+2, -1)
- megatron/mpu/initialize.py (+13, -9)
- megatron/text_generation_utils.py (+250, -103)
- tools/generate_samples_gpt2.py (+15, -2)
megatron/mpu/__init__.py

```diff
@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
```
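These re-exports make the new helpers available on the `mpu` package. A minimal usage sketch (not part of the diff; it assumes `mpu.initialize_model_parallel` has already run on every rank):

```python
# Usage sketch, assuming torch.distributed and mpu are already initialized.
from megatron import mpu

first = mpu.get_pipeline_model_parallel_first_rank()  # global rank of stage 0
last = mpu.get_pipeline_model_parallel_last_rank()    # global rank of the final stage
```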
megatron/mpu/initialize.py

```diff
@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_PIPELINE_GLOBAL_RANKS = None
 
 
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
```
```diff
@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the pipeline model-parallel groups and embedding groups
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
```
```diff
@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
```
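Here `ranks` is the list of global ranks making up one pipeline group, which each member process now caches. A minimal sketch of the grouping arithmetic, assuming Megatron's strided scheme (plain Python, illustrative values):

```python
# Sketch of how pipeline groups partition global ranks, assuming Megatron's
# scheme: world_size // pipeline_size groups, strided across the world.
world_size = 16
pipeline_size = 4                                   # pipeline-model-parallel size

num_pipeline_groups = world_size // pipeline_size   # 4 groups
pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                   for i in range(num_pipeline_groups)]
print(pipeline_groups)
# [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
# A rank in group [1, 5, 9, 13] stores that list as _PIPELINE_GLOBAL_RANKS,
# so index 0 is its first stage and index -1 its last stage.
```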
```diff
@@ -265,21 +268,22 @@ def is_pipeline_last_stage():
 def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
+    """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
 
 
-def get_pipeline_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
-    in the pipeline model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_pipeline_model_parallel_world_size()
-    return global_rank % (global_world_size // local_world_size)
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
+
+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
 
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
```
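A quick worked example of the two lookup styles (plain Python, illustrative values): the tensor-parallel source rank can be computed arithmetically because tensor groups are contiguous blocks of ranks, while the pipeline first/last ranks are read from the cached `_PIPELINE_GLOBAL_RANKS` list because pipeline groups are strided.

```python
# Tensor groups are contiguous: with tensor world size 2,
# ranks [4, 5] form one group, and both map to source rank 4.
def tensor_src_rank(global_rank, tensor_world_size):
    return (global_rank // tensor_world_size) * tensor_world_size

assert tensor_src_rank(5, 2) == 4
assert tensor_src_rank(4, 2) == 4

# Pipeline groups are strided, so the commit instead stores each group's
# global ranks and indexes into the list.
pipeline_global_ranks = [1, 5, 9, 13]    # one group from the 16-GPU example
first_rank = pipeline_global_ranks[0]    # 1  (first stage)
last_rank = pipeline_global_ranks[-1]    # 13 (last stage)
```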
megatron/text_generation_utils.py

(Diff collapsed in the original view: +250, -103. Per the commit message, this file carries the bulk of the change, the pipelined generation logic itself.)
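Since that diff is collapsed, here is a rough, hypothetical sketch of the control flow pipelined generation needs, not the commit's actual code: each decoding step runs the stages in sequence, forwarding hidden states downstream, sampling on the last stage, and broadcasting the new token to every rank so the next step can proceed. All signatures and the `prev_rank`/`next_rank` wiring below are illustrative assumptions.

```python
# Hypothetical sketch of one pipelined decoding step -- NOT the collapsed code.
import torch

from megatron import mpu


def decode_one_token(model, tokens, hidden_size, prev_rank, next_rank):
    """Advance generation by one token across the pipeline stages (sketch).

    prev_rank/next_rank are the global ranks of the neighboring stages; a real
    implementation would derive them from the pipeline group membership.
    """
    batch, seq_len = tokens.size()
    hidden_shape = (batch, seq_len, hidden_size)
    new_token = None

    if mpu.is_pipeline_first_stage():
        # First stage: embed the tokens, ship activations downstream.
        hidden = model(tokens)                         # illustrative call
        torch.distributed.send(hidden, dst=next_rank)
    elif mpu.is_pipeline_last_stage():
        # Last stage: receive activations, produce logits, pick a token.
        hidden = torch.empty(hidden_shape, device='cuda')
        torch.distributed.recv(hidden, src=prev_rank)
        logits = model(hidden)                         # illustrative call
        new_token = torch.argmax(logits[:, -1, :], dim=-1)
    else:
        # Intermediate stage: receive, run local layers, send onward.
        hidden = torch.empty(hidden_shape, device='cuda')
        torch.distributed.recv(hidden, src=prev_rank)
        torch.distributed.send(model(hidden), dst=next_rank)

    # Every rank needs the sampled token to build the next step's input;
    # broadcast it from the last stage using the helper added by this commit.
    if new_token is None:
        new_token = torch.empty((batch,), dtype=torch.long, device='cuda')
    torch.distributed.broadcast(
        new_token, mpu.get_pipeline_model_parallel_last_rank())
    return new_token
```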
tools/generate_samples_gpt2.py
```diff
@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
```
```diff
@@ -36,7 +37,19 @@ def model_provider():
     """Build the model."""
 
     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(num_tokentypes=0, parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=False)
 
     return model
```
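The stage-dependent dispatch reflects how the network is partitioned: the first stage holds the input embedding, the last stage holds the output head (which is why `parallel_output=False` appears only there, so the last stage gathers full-vocabulary logits for sampling), and intermediate stages hold only transformer layers. For context, a simplified sketch of how the script consumes this provider; the calls are Megatron's public entry points, but treat the exact flow as an assumption, with argument handling elided:

```python
# Simplified consumption of model_provider (sketch, not the full script).
from megatron.initialize import initialize_megatron
from megatron.training import get_model

initialize_megatron()              # parses args, sets up mpu process groups
model = get_model(model_provider)  # each rank builds only its own stage
```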