OpenDAS / Megatron-LM / Commits

Commit 5c45db4a, authored Dec 09, 2020 by Jared Casper, committed Dec 19, 2020 by Deepak Narayanan
Parent: caa9dca5

Initial implementation of pipelined text generation
Showing 4 changed files with 280 additions and 115 deletions (+280, -115):

  megatron/mpu/__init__.py            +2    -1
  megatron/mpu/initialize.py          +13   -9
  megatron/text_generation_utils.py   +250  -103
  tools/generate_samples_gpt2.py      +15   -2
megatron/mpu/__init__.py

@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
megatron/mpu/initialize.py

@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 
+_PIPELINE_GLOBAL_RANKS = None
 
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""

@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the pipeline model-parallel groups and embedding groups
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP

@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
@@ -265,21 +268,22 @@ def is_pipeline_last_stage():
 def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
+    """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]

-def get_pipeline_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
-    in the pipeline model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_pipeline_model_parallel_world_size()
-    return global_rank % (global_world_size // local_world_size)

+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
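Note: the new first/last-rank helpers simply index into _PIPELINE_GLOBAL_RANKS, but the rank arithmetic around them is easy to misread. Below is a minimal, torch-free sketch of that arithmetic, assuming Megatron's layout where tensor-model-parallel ranks are contiguous and initialize_model_parallel builds each pipeline group from every (world_size // pipeline_size)-th rank; the helper names and the 8-GPU numbers are illustrative, not part of the commit.

```python
# Illustrative sketch only (not the commit's code).

def tensor_model_parallel_src_rank(global_rank, tensor_mp_size):
    # First (lowest) global rank in this rank's tensor-model-parallel group,
    # mirroring the arithmetic in get_tensor_model_parallel_src_rank().
    return (global_rank // tensor_mp_size) * tensor_mp_size

def pipeline_global_ranks(global_rank, world_size, pipeline_mp_size):
    # Global ranks forming this rank's pipeline group; index 0 plays the role
    # of get_pipeline_model_parallel_first_rank(), index -1 of ..._last_rank().
    num_groups = world_size // pipeline_mp_size
    return list(range(global_rank % num_groups, world_size, num_groups))

# Hypothetical setup: 8 GPUs, tensor-parallel size 2, pipeline-parallel size 2.
assert tensor_model_parallel_src_rank(5, tensor_mp_size=2) == 4
assert pipeline_global_ranks(5, world_size=8, pipeline_mp_size=2) == [1, 5]
```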
megatron/text_generation_utils.py

@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
+from megatron.training import communicate
 from megatron.utils import get_ltor_masks_and_position_ids
@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model):
     # Read the sample file and open the output file.
     assert args.sample_input_file is not None, \
         'sample input file is not provided.'
-    if mpu.get_tensor_model_parallel_rank() == 0:
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)
         input_pos = 0
         if args.sample_output_file is None:
             sample_output_file = args.sample_input_file + ".out"
-            print('could not find `sample-output-file`, setting '
+            print('`sample-output-file` not specified, setting '
                   'it to {}'.format(sample_output_file))
         else:
             sample_output_file = args.sample_output_file
@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model):
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0
 
-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
                 input_pos += 1
                 if input_pos == input_count:
                     raw_text = "stop"
                 raw_text_len = len(raw_text)
 
                 if "stop" in raw_text:
                     terminate_runs = 1
@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model):
                     continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0
 
-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()
 
             if terminate_runs == 1:
                 return
 
+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
 
             token_stream = get_token_stream(model, [context_tokens])
             for _, decode_tokens in enumerate(token_stream):
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                pass
 
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+            if mpu.is_pipeline_first_stage():
+                os.system('clear')
+                print("\nContext:", raw_text, flush=True)
 
-                fname_out.write("\nContext:")
-                fname_out.write(raw_text)
-                fname_out.write("\n\nMegatron-LM:")
-                fname_out.write(trim_decode_tokens)
-                fname_out.write("\n")
+                fname_out.write("\nContext:")
+                fname_out.write(raw_text)
+                raw_text = None
 
+                decode_tokens, _ = decode_tokens
+                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[raw_text_len:]
+                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+                fname_out.write("\n\nMegatron-LM:")
+                fname_out.write(trim_decode_tokens)
+                fname_out.write("\n")
 
             torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
-            raw_text = None
             context_count += 1
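Note: this hunk is why the else-branch now sets context_length = 0 and the single-rank broadcast became an all_reduce — every rank contributes a zero triple except the one that actually read the prompt, so summing across the model-parallel group hands every stage the same metadata. A torch-free sketch of that exchange (the values are hypothetical, and all_reduce_sum below only stands in for torch.distributed.all_reduce):

```python
# Illustration only, not the commit's code. Each rank builds
# [terminate_runs, raw_text_len, context_length]; non-reader ranks contribute
# zeros (that is what the "EMPTY TEXT" branch's context_length = 0 guarantees),
# so an element-wise sum-reduction leaves every rank with the reader's values.

def all_reduce_sum(per_rank_values):
    # Stand-in for torch.distributed.all_reduce with the default SUM op.
    return [sum(column) for column in zip(*per_rank_values)]

reader_rank = [0, 27, 9]       # a 27-char prompt tokenized to 9 tokens (hypothetical)
other_ranks = [[0, 0, 0]] * 3  # the other ranks in the model-parallel group

assert all_reduce_sum([reader_rank] + other_ranks) == [0, 27, 9]
```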
@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24):
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0
+            raw_text_len = 0
 
-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 raw_text = input("\nContext prompt (stop to exit) >>> ")
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")
                 raw_text_len = len(raw_text)
 
                 if "stop" in raw_text:
                     terminate_runs = 1
@@ -194,43 +221,70 @@ def generate_samples_interactive(model, print_frequency=24):
                     continue
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = len(context_tokens)
+                context_length = 0
 
-            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor,
-                                        mpu.get_tensor_model_parallel_src_rank(),
-                                        group=mpu.get_tensor_model_parallel_group())
-            terminate_runs = terminate_runs_tensor[0].item()
+            input_info = [terminate_runs, raw_text_len, context_length]
+            input_info_tensor = torch.cuda.LongTensor(input_info)
+            torch.distributed.all_reduce(input_info_tensor,
+                                         group=mpu.get_model_parallel_group())
+            terminate_runs = input_info_tensor[0].item()
+            raw_text_len = input_info_tensor[1].item()
 
             if terminate_runs == 1:
                 return
 
+            # For pipeline parallel we send context tokens to last stage
+            # so it knows when to start overwriting
+            if mpu.get_tensor_model_parallel_rank() == 0 \
+               and args.pipeline_model_parallel_size > 1:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                if mpu.is_pipeline_last_stage():
+                    src = mpu.get_pipeline_model_parallel_first_rank()
+                    group = mpu.get_embedding_group()
+                    context_length = input_info_tensor[2].item()
+                    context_tokens_tensor = torch.empty(context_length,
+                                                        dtype=torch.int64,
+                                                        device=torch.device("cuda"))
+                    torch.distributed.broadcast(context_tokens_tensor, src, group)
+                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
 
             token_stream = get_token_stream(model, [context_tokens])
 
             for counter, decode_tokens in enumerate(token_stream):
+                if counter % print_frequency != 0 \
+                   or mpu.get_tensor_model_parallel_rank() != 0 \
+                   or not mpu.is_pipeline_first_stage():
+                    continue
+
+                os.system('clear')
+                print("\nContext:", raw_text, flush=True)
+
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[raw_text_len:]
+                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-                if mpu.get_tensor_model_parallel_rank() == 0 and \
-                   counter % print_frequency == 0:
-                    os.system('clear')
-                    print("\nContext:", raw_text, flush=True)
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[len(raw_text):]
-                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
 
-            if mpu.get_tensor_model_parallel_rank() == 0:
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 print("\nContext:", raw_text, flush=True)
 
                 if not isinstance(decode_tokens, list):
                     decode_tokens, _ = decode_tokens
                     decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                 trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[len(raw_text):]
+                    decode_tokens)[raw_text_len:]
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)
+
+                input("\nPress Enter to continue >>>")
 
             raw_text = None
 
             torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             context_count += 1
-
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                input("\nPress any key to continue >>>")
@@ -247,22 +301,31 @@ def generate_samples_unconditional(model):
         for token_stream in get_token_stream(model,
                                              copy.deepcopy(context_tokens)):
+            pass
-            if ctr % args.log_interval == 0:
-                print('Avg s/batch:',
-                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
-                start_time = time.time()
-            length = len(token_stream)
-            token_batch = token_stream[0].cpu().numpy().tolist()
-            length_batch = token_stream[1].cpu().numpy().tolist()
-            for tokens, length in zip(token_batch, length_batch):
-                tokens = tokens[1:length - 1]
-                text = tokenizer.detokenize(tokens)
-                is_finished = length < args.seq_length - 1
-                datum = {'text': text, 'length': length - 1,
-                         'finished': is_finished}
-                yield datum
-                ctr += 1
-                if ctr >= num_samples:
-                    break
+        if mpu.is_pipeline_last_stage() and \
+           mpu.get_tensor_model_parallel_rank() == 0:
+            if ctr % args.log_interval == 0:
+                print('Avg s/batch:',
+                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
+                start_time = time.time()
+            length = len(token_stream)
+            token_batch = token_stream[0].cpu().numpy().tolist()
+            length_batch = token_stream[1].cpu().numpy().tolist()
+            assert len(length_batch) == args.batch_size
+            for tokens, length in zip(token_batch, length_batch):
+                tokens = tokens[1:length - 1]
+                text = tokenizer.detokenize(tokens)
+                is_finished = length < args.seq_length - 1
+                datum = {'text': text, 'length': length - 1,
+                         'finished': is_finished}
+                yield datum
+                ctr += 1
+                if ctr >= num_samples:
+                    break
+        else:
+            for _ in range(args.batch_size):
+                yield None
+                ctr += 1
+                if ctr >= num_samples:
+                    break
         if ctr >= num_samples:
             break
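Note: having every non-last stage yield None once per sample keeps the generator protocol identical on all pipeline stages, so every rank advances through the same number of yields and stays in lockstep even though only the last stage (tensor rank 0) holds real text. A small pure-Python sketch of that lockstep consumption (no Megatron or torch; all names are illustrative):

```python
# Illustration only: mimics how generate_samples_unconditional keeps ranks in
# step. The "last stage" yields real samples, every other rank yields None the
# same number of times, so consuming them together advances all of them.

def last_stage_samples(batch):
    for text in batch:
        yield {"text": text, "finished": True}

def other_stage_samples(batch_size):
    for _ in range(batch_size):
        yield None

batch = ["sample A", "sample B"]
for real, placeholder in zip(last_stage_samples(batch),
                             other_stage_samples(len(batch))):
    # Only the rank holding `real` writes output; the others just advance.
    assert placeholder is None
    print(real["text"])
```

The same pattern shows up in the next hunk, where only the last-stage, tensor-rank-0 process writes the JSON line.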
@@ -273,7 +336,9 @@ def generate_and_write_samples_unconditional(model):
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
         for datum in generate_samples_unconditional(model):
-            f.write(json.dumps(datum) + '\n')
+            if mpu.is_pipeline_last_stage() and \
+               mpu.get_tensor_model_parallel_rank() == 0:
+                f.write(json.dumps(datum) + '\n')


 def pad_batch(batch, pad_id, args):
@@ -313,7 +378,10 @@ def get_token_stream(model, context_tokens):
                                                  attention_mask, position_ids)
     for tokens, lengths in batch_token_iterator:
         context_length += 1
-        yield tokens[:, :context_length], lengths
+        if tokens is not None:
+            yield tokens[:, :context_length], lengths
+        else:
+            yield None, None


 def switch(val1, val2, boolean):
@@ -322,6 +390,60 @@ def switch(val1, val2, boolean):
     return (1 - boolean) * val1 + boolean * val2
 
 
+def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
+                 layer_past=None, get_key_value=None,
+                 forward_method_parallel_output=None):
+
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value,
+                                  forward_method_parallel_output=forward_method_parallel_output)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  tokentype_ids=tokentype_ids,
+                                  layer_past=layer_past,
+                                  get_key_value=get_key_value)
+    elif mpu.is_pipeline_last_stage():
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value,
+                              forward_method_parallel_output=forward_method_parallel_output)
+    else:
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask,
+                              layer_past=layer_past,
+                              get_key_value=get_key_value)
+
+    if get_key_value:
+        output_tensor, layer_past = output_tensor
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output_tensor,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
+    if get_key_value:
+        return output_tensor, layer_past
+    return output_tensor
+
+
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
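Note: the switch helper whose return line opens this hunk is what lets sample_sequence_batch overwrite a position only in rows whose prompt has already been consumed — it blends two tensors element-wise under a boolean mask. A tiny CPU-only illustration (the cast to long is done explicitly here because this sketch does not rely on how the original helper handles dtypes; all values are hypothetical):

```python
import torch

def switch_like(val1, val2, mask):
    # Keep val1 where mask is False, take val2 where mask is True,
    # mirroring `(1 - boolean) * val1 + boolean * val2` from the hunk.
    mask = mask.long()
    return (1 - mask) * val1 + mask * val2

tokens_at_pos = torch.tensor([11, 22, 33])   # current tokens[:, context_length]
sampled = torch.tensor([7, 8, 9])            # `prev`, freshly sampled on the last stage
started = torch.tensor([True, False, True])  # rows whose prompt is already consumed

# Rows still inside their prompt keep the prompt token; started rows are overwritten.
assert switch_like(tokens_at_pos, sampled, started).tolist() == [7, 22, 9]
```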
@@ -349,14 +471,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
 
         while context_length <= (maxlen):
             if args.recompute:
-                logits = model(tokens,
-                               position_ids,
-                               attention_mask,
-                               tokentype_ids=type_ids,
-                               forward_method_parallel_output=False)
-                logits = logits[:, context_length - 1, :]
+                output = forward_step(model, tokens,
+                                      position_ids,
+                                      attention_mask,
+                                      tokentype_ids=type_ids,
+                                      forward_method_parallel_output=False)
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
+                    logits = output[:, context_length - 1, :]
             else:
                 types2use = None
                 if counter == 0:
@@ -372,41 +495,65 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-                logits, layer_past = model(tokens2use,
-                                           positions2use,
-                                           attention_mask,
-                                           layer_past=layer_past,
-                                           get_key_value=True,
-                                           tokentype_ids=types2use,
-                                           forward_method_parallel_output=False)
-                logits = logits[:, -1].view(batch_size, -1).contiguous()
+                logits, layer_past = forward_step(model, tokens2use,
+                                                  positions2use,
+                                                  attention_mask,
+                                                  layer_past=layer_past,
+                                                  get_key_value=True,
+                                                  tokentype_ids=types2use,
+                                                  forward_method_parallel_output=False)
+                if mpu.is_pipeline_last_stage():
+                    assert output is not None
+                    logits = logits[:, -1].view(batch_size, -1).contiguous()
 
-            if args.greedy:
-                prev = torch.argmax(logits, dim=-1).view(-1)
-            else:
-                logits = logits.float()
-                logits /= args.temperature
-                logits = top_k_logits(logits, top_k=args.top_k,
-                                      top_p=args.top_p)
-                log_probs = F.softmax(logits, dim=-1)
-                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
-
-            print_logits = []
-            for p in prev:
-                print_logits.append([logits[i, p].item()
-                                     for i in range(batch_size)])
-            started = context_lengths <= context_length
-            tokens[:, context_length] = switch(
-                tokens[:, context_length].view(-1), prev, started)
-            context_length += 1
-            counter += 1
-
-            done_token = (prev == eos_id).byte() & started.byte()
-            just_finished = (done_token & ~is_done).bool()
-            lengths[just_finished.view(-1)] = context_length
-            is_done = is_done | done_token
-            done = torch.all(is_done)
-
-            yield tokens, lengths
+            if mpu.is_pipeline_last_stage():
+                if args.greedy:
+                    prev = torch.argmax(logits, dim=-1).view(-1)
+                else:
+                    logits = logits.float()
+                    logits /= args.temperature
+                    logits = top_k_logits(logits, top_k=args.top_k,
+                                          top_p=args.top_p)
+                    log_probs = F.softmax(logits, dim=-1)
+                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
+
+                started = context_lengths <= context_length
+
+                new_tokens = switch(
+                    tokens[:, context_length].view(-1), prev, started)
+                tokens[:, context_length] = new_tokens
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                torch.distributed.broadcast(new_tokens, src, group)
+
+                done_token = (prev == eos_id).byte() & started.byte()
+                just_finished = (done_token & ~is_done).bool()
+                lengths[just_finished.view(-1)] = context_length
+                is_done = is_done | done_token
+                done = torch.all(is_done)
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)
+
+                yield tokens, lengths
+
+            else:
+                if mpu.is_pipeline_first_stage():
+                    src = mpu.get_pipeline_model_parallel_last_rank()
+                    group = mpu.get_embedding_group()
+                    new_tokens = torch.empty_like(tokens[:, context_length])
+                    torch.distributed.broadcast(new_tokens, src, group)
+                    tokens[:, context_length] = new_tokens
+                    yield tokens, None
+                else:
+                    yield None, None
+
+                done = torch.cuda.ByteTensor([0])
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_pipeline_model_parallel_group()
+                torch.distributed.broadcast(done, src, group)
+
+            context_length += 1
+            counter += 1
             if done:
                 break
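Note: the sampling path on the last stage (temperature, top-k/top-p filtering, multinomial draw) is standard; since top_k_logits itself is outside this diff, here is a self-contained CPU sketch of that step with a stand-in filter. top_k_filter below is illustrative, not Megatron's top_k_logits, and it ignores top-p; the shapes are hypothetical.

```python
import torch
import torch.nn.functional as F

def top_k_filter(logits, k, filter_value=float("-inf")):
    # Keep only the k largest logits per row; everything else cannot be sampled.
    kth_best = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return torch.where(logits < kth_best,
                       torch.full_like(logits, filter_value),
                       logits)

torch.manual_seed(0)
logits = torch.randn(2, 10)    # [batch_size, vocab_size]
logits = logits.float() / 0.9  # temperature scaling, as in the hunk
probs = F.softmax(top_k_filter(logits, k=3), dim=-1)
prev = torch.multinomial(probs, num_samples=1).view(-1)
print(prev)                    # one next-token id per row
```

On the other pipeline stages none of this runs; per the hunk they only take part in the two broadcasts, so every rank ends up with the sampled token (first stage, via the embedding group) and the shared done flag (all stages, via the pipeline group).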
tools/generate_samples_gpt2.py

@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
@@ -36,7 +37,19 @@ def model_provider():
     """Build the model."""
 
     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(num_tokentypes=0,
+                                       parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=False)
 
     return model