OpenDAS / Megatron-LM / Commits / ce29d4d5

Commit ce29d4d5 authored Apr 02, 2020 by Mohammad

working on refactoring text generation

parent a0bcee94
Changes: 5. Showing 5 changed files with 189 additions and 225 deletions (+189 -225).
generate_samples.py              +174 -198
megatron/arguments.py              +1  -23
megatron/model/bert_model.py       +1   -2
megatron/model/gpt2_model.py       +6   -2
megatron/tokenizer/tokenizer.py    +7   -0
generate_samples.py (view file @ ce29d4d5)

@@ -39,117 +39,87 @@ from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
+from megatron import print_rank_0
-def get_model(args):
+def model_provider():
     """Build the model."""

+    args = get_args()
+
     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      parallel_output=False)
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
-
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-
-    # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
-
-    # Wrap model for distributed training.
-    model = DDP(model)
+    model = GPT2Model(num_tokentypes=0, parallel_output=False)

     return model
-def setup_model(args):
-    """Setup model and optimizer."""
-
-    model = get_model(args)
-
-    if args.load is not None:
-        _ = load_checkpoint(model, None, None, args)
-
-    return model
-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(args.batch_size, -1).contiguous()
-    device = args.device
-    tokens = tokens.to(device)
+def get_batch(context_tokens):
+    """Generate batch from context tokens."""
+    args = get_args()
+    tokenizer = get_tokenizer()

-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    # Move to GPU.
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+    # Get the attention mask and postition ids.
+    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
-        args.eod_token,
+        tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        False)

     # Fp16 conversion.
     if args.fp16:
         attention_mask = attention_mask.half()

     return tokens, attention_mask, position_ids
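For intuition, here is a minimal stand-alone sketch (not part of this commit, and not the real get_ltor_masks_and_position_ids) of the kind of left-to-right attention mask and position ids this batch-construction step produces for a toy batch; shapes and values are illustrative only:

import torch

batch_size, seq_length = 2, 5
tokens = torch.arange(batch_size * seq_length).view(batch_size, seq_length)

# Lower-triangular (causal) mask: position i may attend to positions <= i.
attention_mask = torch.tril(
    torch.ones(seq_length, seq_length)).unsqueeze(0).expand(
        batch_size, -1, -1)

# Position ids simply count 0..seq_length-1 for every row of the batch.
position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)

print(attention_mask[0])
print(position_ids)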
 def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+    """ This function has been mostly taken from huggingface conversational
+     ai code at
+         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+                    conversational-ai-with-transfer-learning-2d818ac26313 """

     if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value

     if top_p > 0.0:
-        #convert to 1D
-        # logits=logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        # Cconvert to 1D
+        sorted_logits, sorted_indices = torch.sort(
+            logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
+                                        dim=-1)

         # Remove tokens with cumulative probability above the threshold
         sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] \
+            = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
         for i in range(sorted_indices.size(0)):
             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
             logits[i][indices_to_remove] = filter_value
-        #going back to 2D
-        # logits=logits.view(1, -1).contiguous()

     return logits
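For reference, a minimal self-contained sketch (not part of this commit) showing how this kind of top-k / top-p filter reshapes a toy logits tensor before sampling; the tensor values are made up for illustration:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -3.0]])  # toy values

# Top-k = 3: everything below the third-largest logit is masked out.
top_k = 3
kth = torch.topk(logits, top_k)[0][..., -1, None]
logits_k = logits.clone()
logits_k[logits_k < kth] = -float('Inf')
print(F.softmax(logits_k, dim=-1))   # probability mass only on the 3 kept tokens

# Top-p = 0.9: keep the smallest prefix of sorted tokens covering 90% of the mass.
top_p = 0.9
logits_p = logits.clone()
sorted_logits, sorted_indices = torch.sort(logits_p, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift right so the first token that crosses the threshold is still kept.
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
for i in range(sorted_indices.size(0)):
    logits_p[i][sorted_indices[i][sorted_indices_to_remove[i]]] = -float('Inf')
print(F.softmax(logits_p, dim=-1))   # probability mass only on the nucleus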
-def generate_samples_input_from_file(model, tokenizer, args):
-
-    if args.sample_input_file == "":
-        if mpu.get_model_parallel_rank() == 0:
-            print("args.sample_input_file CAN NOT BE empty!\n")
-        return
+def generate_samples_input_from_file(model):
+    """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

+    # Read the sample file and open the output file.
+    assert args.sample_input_file is not None, \
+        'sample input file is not provided.'
     if mpu.get_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)
         input_pos = 0
-        if args.sample_output_file == "":
-            print("Argument: sample-output-file can't be empty, setting it to\n")
-            print("\targs.sample_input_file.out")
-            args.sample_output_file = args.sample_input_file + ".out"
-        fname_out = open(args.sample_output_file, "w+")
+        if args.sample_output_file is None:
+            sample_output_file = args.sample_input_file + ".out"
+            print('could not find `sample-output-file`, setting '
+                  'it to {}'.format(sample_output_file))
+        fname_out = open(sample_output_file, "w+")

     context_count = 0
     model.eval()
@@ -167,46 +137,44 @@ def generate_samples_input_from_file(model, tokenizer, args):
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
-                   context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+                   context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

-                   if context_length >= args.seq_length // 2:
+                   if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
-                             "\nPlease give smaller context (half of the sequence length)!")
+                             "\nPlease give smaller context (half of the "
+                             "sequence length)!", flush=True)
                        continue
            else:
-               context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
+               context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-           torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+           torch.distributed.broadcast(terminate_runs_tensor,
+                                       mpu.get_model_parallel_src_rank(),
+                                       group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
-           token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
+           token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                # token_end = decode_tokens.find("<|endoftext|>")
                # if token_end > 0:
                #     break
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
-               trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-               #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+               trim_decode_tokens = tokenizer.detokenize(decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                #fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
                fname_out.write("\n")

            raw_text = None
@@ -214,9 +182,11 @@ def generate_samples_input_from_file(model, tokenizer, args):
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


-def generate_samples_interactive(model, tokenizer, args):
-
-    print_frequency = 24
+def generate_samples_interactive(model, print_frequency=24):
+    """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
@@ -235,79 +205,81 @@ def generate_samples_interactive(model, tokenizer, args):
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
-                   context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+                   context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

-                   if context_length >= args.seq_length // 2:
+                   if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
-                             "\nPlease give smaller context (half of the sequence length)!")
+                             "\nPlease give smaller context (half of the "
+                             "sequence length)!", flush=True)
                        continue
            else:
-               context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
+               context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-           torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+           torch.distributed.broadcast(terminate_runs_tensor,
+                                       mpu.get_model_parallel_src_rank(),
+                                       group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
-           token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
+           token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                # token_end = decode_tokens.find("<|endoftext|>")
                # if token_end > 0:
                #     break
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

-               if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
+               if mpu.get_model_parallel_rank() == 0 and \
+                  counter % print_frequency == 0:
                    os.system('clear')
                    #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                    print("\nContext:", raw_text, flush=True)
-                   trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-                   #print("\nGPT2:", trim_decode_tokens, flush=True)
-                   #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+                   trim_decode_tokens = tokenizer.detokenize(decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
-               trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-               #print("\nGPT2:", trim_decode_tokens, flush=True)
-               #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+               trim_decode_tokens = tokenizer.detokenize(decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")
-def generate_samples_unconditional(model, tokenizer, args):
+def generate_samples_unconditional(model):
+    """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

    num_samples = args.num_samples
-   context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
+   context_tokens = [[tokenizer.eod] for _ in range(args.batch_size)]
    samples = []
    # with open(args.genfile, 'w') as f:
    ctr = 0
    while True:
        start_time = time.time()
-       for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
+       for token_stream in get_token_stream(model,
+                                            copy.deepcopy(context_tokens)):
            pass
        # token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
        if ctr % args.log_interval == 0:
-           print('Avg s/batch:', (time.time() - start_time) / min(args.log_interval, ctr + 1))
+           print('Avg s/batch:',
+                 (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
-           text = tokenizer.DecodeIds(tokens)
+           text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
            yield datum
@@ -317,35 +289,42 @@ def generate_samples_unconditional(model, tokenizer, args):
        if ctr >= num_samples:
            break


-def write_and_generate_samples_unconditional(model, tokenizer, args):
+def write_and_generate_samples_unconditional(model):
+    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
-       for datum in generate_samples_unconditional(model, tokenizer, args):
+       for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')
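For context, a small self-contained sketch (not part of this commit) of the JSON-lines layout this writer produces and how a downstream script might read it back; the file name and sample contents are made up:

import json

# One JSON object per line, mirroring the `datum` dict yielded above.
samples = [
    {'text': 'The quick brown fox ...', 'length': 128, 'finished': True},
    {'text': 'Once upon a time ...', 'length': 1023, 'finished': False},
]

with open('unconditional_samples.jsonl', 'w') as f:   # hypothetical genfile
    for datum in samples:
        f.write(json.dumps(datum) + '\n')

# Reading it back, one record per line.
with open('unconditional_samples.jsonl') as f:
    records = [json.loads(line) for line in f]
print(len(records), records[0]['finished'])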
def pad_batch(batch, tokenizer, args):

-   pad_id = tokenizer.get_command('pad').Id
+   pad_id = tokenizer.eod
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
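To make the padding contract concrete, here is a minimal stand-alone sketch (not part of this commit) of what pad_batch does to a ragged batch; the signature is simplified and the pad id and sequence length are arbitrary toy values:

def pad_batch_sketch(batch, pad_id, seq_length):
    # Pad every sequence up to seq_length and remember its true length.
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

batch = [[5, 7, 9], [1, 2, 3, 4, 5]]
padded, lengths = pad_batch_sketch(batch, pad_id=0, seq_length=8)
print(padded)   # [[5, 7, 9, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 0, 0, 0]]
print(lengths)  # [3, 5]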
-def get_token_stream(model, context_tokens, tokenizer, args):
-
-    pad_id = tokenizer.get_command('pad').Id
-    # context_length = len(context_tokens)
-    # if context_length < args.seq_length:
-    #     context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
+def get_token_stream(model, context_tokens):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    pad_id = tokenizer.eod
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    # context_length_tensor = torch.cuda.LongTensor([context_length])

-   torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
-   torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+   torch.distributed.broadcast(context_length_tensor,
+                               mpu.get_model_parallel_src_rank(),
+                               group=mpu.get_model_parallel_group())
+   torch.distributed.broadcast(context_tokens_tensor,
+                               mpu.get_model_parallel_src_rank(),
+                               group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
@@ -355,7 +334,9 @@ def get_token_stream(model, context_tokens, tokenizer, args):
    layer_past = None

-   batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
+   batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                context_length_tensor,
+                                                attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
@@ -365,14 +346,14 @@ def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
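A quick, self-contained illustration (not part of this commit) of the element-wise select that switch() implements: positions where the boolean mask is 1 take val2, the rest keep val1.

import torch

def switch(val1, val2, boolean):
    # Cast the mask to the value dtype so it can be used in arithmetic.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

val1 = torch.tensor([10, 20, 30, 40])
val2 = torch.tensor([1, 2, 3, 4])
mask = torch.tensor([0, 1, 0, 1])
print(switch(val1, val2, mask))  # tensor([10,  2, 30,  4])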
-def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
-
-    actual_model = model
-    if isinstance(actual_model, DDP):
-        actual_model = actual_model.module
-    if isinstance(actual_model, FP16_Module):
-        actual_model = actual_model.module
-    original_output_parallel = actual_model.parallel_output
-    actual_model.parallel_output = False
+def sample_sequence_batch(model, context_tokens, context_lengths,
+                          attention_mask, position_ids,
+                          maxlen=None, type_ids=None):
+    """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
@@ -395,7 +376,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
        while context_length <= (maxlen):

            if args.recompute:
-               logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
+               logits = model(tokens,
+                              position_ids,
+                              attention_mask,
+                              tokentype_ids=type_ids,
+                              forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
@@ -405,11 +390,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
-                   tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
-                   positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
+                   tokens2use = tokens[:, context_length - 1].view(
+                       batch_size, -1)
+                   positions2use = position_ids[:, context_length - 1].view(
+                       batch_size, -1)
                    if type_ids is not None:
-                       types2use = type_ids[:, context_length - 1].view(batch_size, -1)
-               logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
+                       types2use = type_ids[:, context_length - 1].view(
+                           batch_size, -1)
+               logits, layer_past = model(tokens2use,
+                                          positions2use,
+                                          attention_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=True,
+                                          tokentype_ids=types2use,
+                                          forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
@@ -417,15 +411,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
            else:
                logits = logits.float()
                logits /= args.temperature
-               logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+               logits = top_k_logits(logits, top_k=args.top_k,
+                                     top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                print_logits = []
                for p in prev:
-                   print_logits.append([logits[i, p].item() for i in range(batch_size)])
+                   print_logits.append([logits[i, p].item()
+                                        for i in range(batch_size)])

            started = context_lengths <= context_length
-           tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
+           tokens[:, context_length] = switch(
+               tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1
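For orientation, a minimal stand-alone sketch (not part of this commit) of the sampling step used here: temperature scaling, top-k / top-p filtering (stubbed out), softmax, and one multinomial draw per batch row; the values are toy numbers:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, vocab_size = 2, 6
logits = torch.randn(batch_size, vocab_size)   # stand-in for model output
temperature = 0.7

logits = logits.float()
logits /= temperature
# top_k_logits(logits, top_k=..., top_p=...) would be applied here.
probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(probs, num_samples=1).view(-1)
print(prev)  # one sampled token id per batch row, shape (batch_size,)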
@@ -439,75 +436,54 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                yield tokens, lengths

            if done:
                break

-   actual_model.parallel_output = original_output_parallel

-def prepare_tokenizer(args):
+def add_text_generate_args(parser):
+    """Text generate arguments."""

+    group = parser.add_argument_group('Text generation', 'configurations')
+    group.add_argument("--temperature", type=float, default=1.0)
+    group.add_argument("--greedy", action='store_true', default=False)
+    group.add_argument("--top_p", type=float, default=0.0)
+    group.add_argument("--top_k", type=int, default=0)
+    group.add_argument("--out-seq-length", type=int, default=1024)
+    group.add_argument("--sample-input-file", type=str, default=None,
+                       help='get input from file instead of interactive mode, '
+                       'each line is an input')
+    group.add_argument("--sample-output-file", type=str, default=None,
+                       help='output file got from --sample-input-file')
+    group.add_argument("--num-samples", type=int, default=0,
+                       help='number of samples to generate unconditionally, '
+                       'defaults to 0 and interactive conditional sampling')
+    group.add_argument("--genfile", type=str,
+                       help='output file when generating unconditionally')
+    group.add_argument("--recompute", action='store_true',
+                       help='during generation recompute all attention '
+                       'instead of using previously computed keys/values.')
+    return parser
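As a sanity check, a small self-contained sketch (not part of this commit) of how the flags defined by add_text_generate_args above parse with plain argparse; the option names and defaults mirror the group, but the parser description and sample values are made up:

import argparse

parser = argparse.ArgumentParser(description='text generation sketch')
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--num-samples", type=int, default=0)
group.add_argument("--genfile", type=str)

args = parser.parse_args(
    ["--temperature", "0.9", "--top_k", "40", "--num-samples", "8",
     "--genfile", "samples.jsonl"])
# Dashes in option names become underscores on the namespace.
print(args.temperature, args.top_k, args.num_samples, args.out_seq_length)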
-    tokenizer_args = {
-        'tokenizer_type': args.tokenizer_type,
-        'corpus': None,
-        'model_path': args.tokenizer_path,
-        'vocab_size': args.vocab_size,
-        'model_type': args.tokenizer_model_type,
-        'cache_dir': args.cache_dir}
-    tokenizer = make_tokenizer(**tokenizer_args)
-
-    args.tokenizer_num_tokens = tokenizer.num_tokens
-    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
-    args.eod_token = tokenizer.get_command('eos').Id
-
-    after = tokenizer.num_tokens
-    multiple = args.make_vocab_size_divisible_by * \
-        mpu.get_model_parallel_world_size()
-    if multiple != 0:
-        while (after % multiple) != 0:
-            after += 1
-    args.vocab_size = after
-    print("prepare tokenizer done", flush=True)
-
-    return tokenizer
 def main():
-    """Main training program."""
+    """Main program."""

     print('Generate Samples')

     # Disable CuDNN.
     torch.backends.cudnn.enabled = False

-    # Timer.
-    timers = Timers()
-
-    # Arguments.
-    args = get_args()
-
-    # Pytorch distributed.
-    initialize_distributed(args)
-
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
-
-    #get the tokenizer
-    tokenizer = prepare_tokenizer(args)
-
-    # Model, optimizer, and learning rate.
-    model = setup_model(args)
-
-    #setting default batch size to 1
-    # args.batch_size = 1
-    args.device = torch.cuda.current_device()
+    initialize_megatron(extra_args_provider=add_text_generate_args,
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+
+    # Set up model and load checkpoint.
+    model = get_model(model_provider)
+    if args.load is not None:
+        _ = load_checkpoint(model, None, None)

     #generate samples
     if args.num_samples == 0:
         args.batch_size = 1
         if args.sample_input_file != "":
-            generate_samples_input_from_file(model, tokenizer, args)
+            generate_samples_input_from_file(model)
         else:
-            generate_samples_interactive(model, tokenizer, args)
+            generate_samples_interactive(model)
     else:
-        write_and_generate_samples_unconditional(model, tokenizer, args)
+        write_and_generate_samples_unconditional(model)


 if __name__ == "__main__":
megatron/arguments.py (view file @ ce29d4d5)

@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):
-def add_text_generate_args(parser):
-    """Text generate arguments."""
-
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
-    group.add_argument("--sample-input-file", type=str, default="",
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input')
-    group.add_argument("--sample-output-file", type=str, default="",
-                       help='output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-    return parser
-
-
 def add_data_args_(parser):
megatron/model/bert_model.py (view file @ ce29d4d5)

@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
         self._binary_head_key = 'binary_head'

-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
megatron/model/gpt2_model.py (view file @ ce29d4d5)

@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
     def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):

         # Language model.
         lm_output = self.language_model(input_ids,

@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
             lm_output, presents = lm_output

         # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
         output = parallel_lm_logits(
             lm_output,
             self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)

         if get_key_value:
             output = [output, presents]
megatron/tokenizer/tokenizer.py (view file @ ce29d4d5)

@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
     def tokenize(self, text):
         pass

+    def detokenize(self, token_ids):
+        raise NotImplementedError('detokenizer is not implemented for {} '
+                                  'tokenizer'.format(self.name))
+
     @property
     def cls(self):
         raise NotImplementedError('CLS is not provided for {} '

@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def tokenize(self, text):
         return self.tokenizer.encode(text)

+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
     @property
     def eod(self):
         return self.eod_id
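To illustrate the tokenize/detokenize contract these hunks add, here is a small self-contained sketch (not part of this commit) with a toy whitespace tokenizer standing in for the GPT-2 BPE one; the class names and vocabulary are made up:

from abc import ABC, abstractmethod

class AbstractTokenizerSketch(ABC):
    """Minimal stand-in for the abstract tokenizer interface."""

    def __init__(self, name):
        self.name = name

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        # Default mirrors the diff: subclasses must opt in to detokenization.
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

class ToyWhitespaceTokenizer(AbstractTokenizerSketch):
    """Toy tokenizer: ids are indices into a fixed word list."""

    def __init__(self):
        super().__init__('toy whitespace')
        self.vocab = ['hello', 'world', 'megatron']

    def tokenize(self, text):
        return [self.vocab.index(w) for w in text.split()]

    def detokenize(self, token_ids):
        return ' '.join(self.vocab[i] for i in token_ids)

tok = ToyWhitespaceTokenizer()
ids = tok.tokenize('hello megatron')
print(ids)                  # [0, 2]
print(tok.detokenize(ids))  # 'hello megatron'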