OpenDAS / Megatron-LM · Commit 5c45db4a

Initial implementation of pipelined text generation

Authored Dec 09, 2020 by Jared Casper
Committed by Deepak Narayanan, Dec 19, 2020
parent caa9dca5

Showing 4 changed files with 280 additions and 115 deletions (+280, -115)
megatron/mpu/__init__.py           +2    -1
megatron/mpu/initialize.py         +13   -9
megatron/text_generation_utils.py  +250  -103
tools/generate_samples_gpt2.py     +15   -2
megatron/mpu/__init__.py

@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
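The hunk above replaces the single get_pipeline_model_parallel_src_rank export with explicit first-rank and last-rank helpers. A minimal usage sketch, assuming mpu has already been initialized with initialize_model_parallel; the broadcast itself is illustrative and not part of this file:

    # Hypothetical helper: broadcast a tensor from the first pipeline stage to
    # every rank in the same pipeline group, using the newly exported helpers.
    import torch
    from megatron import mpu

    def broadcast_from_first_stage(tensor):
        src = mpu.get_pipeline_model_parallel_first_rank()   # global rank of stage 0
        group = mpu.get_pipeline_model_parallel_group()
        torch.distributed.broadcast(tensor, src, group=group)
        return tensor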
megatron/mpu/initialize.py

@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_PIPELINE_GLOBAL_RANKS = None


 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""

@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the pipeline model-parallel groups and embedding groups
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP

@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:

@@ -265,21 +268,22 @@ def is_pipeline_last_stage():
 def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
+    """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


-def get_pipeline_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
-    in the pipeline model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_pipeline_model_parallel_world_size()
-    return global_rank % (global_world_size // local_world_size)
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
+
+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
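The new module-level _PIPELINE_GLOBAL_RANKS list caches the global ranks of the caller's pipeline group in stage order, so the first- and last-stage ranks become simple lookups instead of arithmetic on the global rank. A standalone sketch of that bookkeeping with toy rank values and no torch.distributed; the real functions live in this file and are filled in by initialize_model_parallel():

    # Toy illustration of the lookup pattern added above; the rank values are made up.
    _PIPELINE_GLOBAL_RANKS = [2, 6, 10, 14]   # one global rank per pipeline stage

    def first_stage_rank():
        assert _PIPELINE_GLOBAL_RANKS is not None, \
            "Pipeline parallel group is not initialized"
        return _PIPELINE_GLOBAL_RANKS[0]

    def last_stage_rank():
        assert _PIPELINE_GLOBAL_RANKS is not None, \
            "Pipeline parallel group is not initialized"
        return _PIPELINE_GLOBAL_RANKS[-1]   # same as indexing with world_size - 1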
megatron/text_generation_utils.py

@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
+from megatron.training import communicate
 from megatron.utils import get_ltor_masks_and_position_ids

@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model):
     # Read the sample file and open the output file.
     assert args.sample_input_file is not None, \
         'sample input file is not provided.'
-    if mpu.get_tensor_model_parallel_rank() == 0:
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)
         input_pos = 0
         if args.sample_output_file is None:
             sample_output_file = args.sample_input_file + ".out"
-            print('could not find `sample-output-file`, setting '
+            print('`sample-output-file` not specified, setting '
                   'it to {}'.format(sample_output_file))
         else:
             sample_output_file = args.sample_output_file

@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model):
     model.eval()
     with torch.no_grad():
         while True:
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
                 input_pos += 1
                 if input_pos == input_count:
                     raw_text = "stop"
+                raw_text_len = len(raw_text)

                 if "stop" in raw_text:
                     terminate_runs = 1

@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model):
                     continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0

-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()

             if terminate_runs == 1:
                 return

+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
             token_stream = get_token_stream(model, [context_tokens])
             for _, decode_tokens in enumerate(token_stream):
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                pass

             if mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[len(raw_text):]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-                fname_out.write("\nContext:")
-                fname_out.write(raw_text)
-                fname_out.write("\n\nMegatron-LM:")
-                fname_out.write(trim_decode_tokens)
-                fname_out.write("\n")
+                if mpu.is_pipeline_first_stage():
+                    os.system('clear')
+                    print("\nContext:", raw_text, flush=True)
+
+                    fname_out.write("\nContext:")
+                    fname_out.write(raw_text)
+
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+                    fname_out.write("\n\nMegatron-LM:")
+                    fname_out.write(trim_decode_tokens)
+                    fname_out.write("\n")

             raw_text = None
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             context_count += 1

@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24):
     model.eval()
     with torch.no_grad():
         while True:
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 raw_text = input("\nContext prompt (stop to exit) >>> ")
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")
+                raw_text_len = len(raw_text)

                 if "stop" in raw_text:
                     terminate_runs = 1

@@ -194,43 +221,70 @@ def generate_samples_interactive(model, print_frequency=24):
                     continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0

-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()

             if terminate_runs == 1:
                 return

+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
+
             token_stream = get_token_stream(model, [context_tokens])
             for counter, decode_tokens in enumerate(token_stream):
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                if mpu.get_tensor_model_parallel_rank() == 0 and \
-                   counter % print_frequency == 0:
-                    os.system('clear')
-                    print("\nContext:", raw_text, flush=True)
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[len(raw_text):]
-                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+                if counter % print_frequency != 0 \
+                   or mpu.get_tensor_model_parallel_rank() != 0 \
+                   or not mpu.is_pipeline_first_stage():
+                    continue
+
+                os.system('clear')
+                print("\nContext:", raw_text, flush=True)
+
+                decode_tokens, _ = decode_tokens
+                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[raw_text_len:]
+                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+                if not isinstance(decode_tokens, list):
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[raw_text_len:]
+                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+                input("\nPress Enter to continue >>>")

             raw_text = None
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             context_count += 1
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                input("\nPress any key to continue >>>")

@@ -247,6 +301,8 @@ def generate_samples_unconditional(model):
         for token_stream in get_token_stream(model,
                                              copy.deepcopy(context_tokens)):
             pass
+        if mpu.is_pipeline_last_stage() and \
+           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))

@@ -254,6 +310,7 @@ def generate_samples_unconditional(model):
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
+           assert len(length_batch) == args.batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)

@@ -263,6 +320,12 @@ def generate_samples_unconditional(model):
                ctr += 1
                if ctr >= num_samples:
                    break
+        else:
+            for _ in range(args.batch_size):
+                yield None
+                ctr += 1
+                if ctr >= num_samples:
+                    break
        if ctr >= num_samples:
            break

@@ -273,6 +336,8 @@ def generate_and_write_samples_unconditional(model):
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
         for datum in generate_samples_unconditional(model):
+            if mpu.is_pipeline_last_stage() and \
+               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')

@@ -313,7 +378,10 @@ def get_token_stream(model, context_tokens):
                                                  attention_mask, position_ids)
     for tokens, lengths in batch_token_iterator:
         context_length += 1
-        yield tokens[:, :context_length], lengths
+        if tokens is not None:
+            yield tokens[:, :context_length], lengths
+        else:
+            yield None, None


 def switch(val1, val2, boolean):

@@ -322,6 +390,60 @@ def switch(val1, val2, boolean):
     return (1 - boolean) * val1 + boolean * val2


+def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
+                 layer_past=None, get_key_value=None,
+                 forward_method_parallel_output=None):
+
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value,
+                                  forward_method_parallel_output=forward_method_parallel_output)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value)
+    elif mpu.is_pipeline_last_stage():
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value,
+                              forward_method_parallel_output=forward_method_parallel_output)
+    else:
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value)
+
+    if get_key_value:
+        output_tensor, layer_past = output_tensor
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output_tensor,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
+    if get_key_value:
+        return output_tensor, layer_past
+    return output_tensor
+
+
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):

@@ -349,14 +471,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen

         while context_length <= (maxlen):
             if args.recompute:
-                logits = model(tokens,
-                               position_ids,
-                               attention_mask,
-                               tokentype_ids=type_ids,
-                               forward_method_parallel_output=False)
-                logits = logits[:, context_length - 1, :]
+                output = forward_step(model, tokens,
+                                      position_ids,
+                                      attention_mask,
+                                      tokentype_ids=type_ids,
+                                      forward_method_parallel_output=False)
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
+                    logits = output[:, context_length - 1, :]
             else:
                 types2use = None
                 if counter == 0:

@@ -372,15 +495,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
-                logits, layer_past = model(tokens2use,
-                                           positions2use,
-                                           attention_mask,
-                                           layer_past=layer_past,
-                                           get_key_value=True,
-                                           tokentype_ids=types2use,
-                                           forward_method_parallel_output=False)
-                logits = logits[:, -1].view(batch_size, -1).contiguous()
+                logits, layer_past = forward_step(model, tokens2use,
+                                                  positions2use,
+                                                  attention_mask,
+                                                  layer_past=layer_past,
+                                                  get_key_value=True,
+                                                  tokentype_ids=types2use,
+                                                  forward_method_parallel_output=False)
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
+                    logits = logits[:, -1].view(batch_size, -1).contiguous()

+            if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:

@@ -391,22 +517,43 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)

+                print_logits = []
+                for p in prev:
+                    print_logits.append([logits[i, p].item()
+                                         for i in range(batch_size)])
                 started = context_lengths <= context_length
-                tokens[:, context_length] = switch(
-                    tokens[:, context_length].view(-1), prev, started)
-                context_length += 1
-                counter += 1
+                new_tokens = switch(
+                    tokens[:, context_length].view(-1), prev, started)
+                tokens[:, context_length] = new_tokens
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                torch.distributed.broadcast(new_tokens, src, group)

                 done_token = (prev == eos_id).byte() & started.byte()
                 just_finished = (done_token & ~is_done).bool()
                 lengths[just_finished.view(-1)] = context_length
                 is_done = is_done | done_token
                 done = torch.all(is_done)
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)

                 yield tokens, lengths

+            else:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_last_rank()
+                    group = mpu.get_embedding_group()
+                    new_tokens = torch.empty_like(tokens[:, context_length])
+                    torch.distributed.broadcast(new_tokens, src, group)
+                    tokens[:, context_length] = new_tokens
+                    yield tokens, None
+                else:
+                    yield None, None
+
+                done = torch.cuda.ByteTensor([0])
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)
+
+            context_length += 1
+            counter += 1
             if done:
                 break
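In the pipelined flow above, the first stage ships the prompt tokens to the last stage over the embedding group, forward_step relays activations between stages with communicate(), and after every step the last stage broadcasts the sampled token and a done flag so all stages stay in sync. A rough sketch of that per-step handshake, with function and group names as in the diff but the sampling and done logic reduced to placeholders:

    # Simplified sketch of one decoding step's communication, assuming mpu is
    # initialized and `tokens` is the shared [batch, seq] buffer on every rank.
    import torch
    from megatron import mpu

    def share_step(tokens, context_length, new_tokens=None, finished=False):
        src = mpu.get_pipeline_model_parallel_last_rank()
        if mpu.is_pipeline_last_stage():
            # Last stage sampled `new_tokens`; send them to the first stage.
            torch.distributed.broadcast(new_tokens, src,
                                        group=mpu.get_embedding_group())
            tokens[:, context_length] = new_tokens
            done = torch.cuda.ByteTensor([int(finished)])
        else:
            if mpu.is_pipeline_first_stage():
                # First stage receives the token so its input buffer stays in sync.
                new_tokens = torch.empty_like(tokens[:, context_length])
                torch.distributed.broadcast(new_tokens, src,
                                            group=mpu.get_embedding_group())
                tokens[:, context_length] = new_tokens
            done = torch.cuda.ByteTensor([0])
        # Every stage learns whether generation has finished.
        torch.distributed.broadcast(done, src,
                                    group=mpu.get_pipeline_model_parallel_group())
        return tokens, bool(done.item())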
tools/generate_samples_gpt2.py

@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file

@@ -36,6 +37,18 @@ def model_provider():
     """Build the model."""

     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(num_tokentypes=0,
+                                       parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=False)

     return model
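model_provider now picks a stage-specific GPT-2 variant whenever pipeline parallelism is enabled; only the last stage sets parallel_output=False so sampling sees a full, unsplit vocabulary. A hedged restatement of the same dispatch in isolation, with class and helper names taken from the diff and the surrounding Megatron setup assumed:

    # Stand-alone sketch of the dispatch introduced in model_provider().
    from megatron import mpu
    from megatron.model import (GPT2Model, GPT2ModelFirstStage,
                                GPT2ModelLastStage, GPT2ModelIntermediateStage)

    def build_generation_model():
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage():
                return GPT2ModelFirstStage(num_tokentypes=0)
            if mpu.is_pipeline_last_stage():
                # Unparallelized logits so the last stage can sample directly.
                return GPT2ModelLastStage(num_tokentypes=0, parallel_output=False)
            return GPT2ModelIntermediateStage(num_tokentypes=0)
        return GPT2Model(num_tokentypes=0, parallel_output=False)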