OpenDAS / Megatron-LM · Commits

Commit a6ba254f, authored Apr 02, 2020 by Mohammad

    generate samples linted

Parent: a19820b1
Showing 2 changed files with 72 additions and 82 deletions:

    generate_samples.py   +72  -82
    pretrain_gpt2.py      +0   -0
generate_samples.py  (mode 100755 → 100644)
@@ -15,34 +15,27 @@
 """Sample Generate GPT2"""

-import os
-import random
-import json
 import copy
-import numpy as np
+import json
+import os
+import time

 import torch
 import torch.nn.functional as F
-import argparse
-import time

-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids
+
+
+def model_provider():
+    """Build the model."""
+
+    args = get_args()
+
+    print_rank_0('building GPT2 model ...')
+    model = GPT2Model(num_tokentypes=0, parallel_output=False)
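Note: the diff cuts model_provider() off before its return. For orientation, a minimal sketch of the provider pattern the new imports enable, assuming the function simply ends by returning the constructed module (example_provider is an illustrative name, not part of the commit):

    # Sketch only: assumes model_provider() ends with `return model`.
    # initialize_megatron() must have run so that get_args() is populated.
    def example_provider():
        model = GPT2Model(num_tokentypes=0, parallel_output=False)
        return model

    # get_model (from megatron.training) calls the provider and applies
    # the configured wrappers (e.g. fp16, data parallelism):
    # model = get_model(example_provider)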
@@ -56,7 +49,7 @@ def get_batch(context_tokens):
+    tokenizer = get_tokenizer()

     # Move to GPU.
     tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()

     # Get the attention mask and position ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
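Note: get_batch() relies on get_ltor_masks_and_position_ids(), whose body is not part of this diff. Its left-to-right semantics can be sketched as follows; the real helper also returns a loss mask (the `_` above) and supports eod-based resets, which this illustrative stand-in omits:

    import torch

    def ltor_masks_and_position_ids_sketch(tokens):
        # tokens: [batch, seq]. A lower-triangular mask lets position i
        # attend only to positions j <= i (left-to-right generation).
        batch_size, seq_length = tokens.size()
        attention_mask = torch.tril(torch.ones(
            batch_size, 1, seq_length, seq_length, device=tokens.device))
        # Position ids are simply 0..seq_length-1 for every sample.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=tokens.device)
        position_ids = position_ids.unsqueeze(0).expand_as(tokens)
        return attention_mask, position_ids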
@@ -80,7 +73,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value

     if top_p > 0.0:
-        # convert to 1D
+        # Convert to 1D.
         sorted_logits, sorted_indices = torch.sort(
@@ -98,12 +91,12 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         for i in range(sorted_indices.size(0)):
             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
             logits[i][indices_to_remove] = filter_value

     return logits


 def generate_samples_input_from_file(model):
     """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()
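Note: pieced together from the two hunks above, the filtering in top_k_logits() amounts to the standard top-k/nucleus scheme. A self-contained restatement follows; the right-shift that keeps the single most likely token is elided by the diff and is assumed here:

    import torch
    import torch.nn.functional as F

    def top_k_top_p_filter_sketch(logits, top_k=0, top_p=0.0,
                                  filter_value=-float('Inf')):
        # logits: [batch, vocab], modified in place as in the original.
        if top_k > 0:
            # Mask everything strictly below the k-th largest logit per row.
            kth = torch.topk(logits, top_k)[0][..., -1, None]
            logits[logits < kth] = filter_value
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1)
            # Drop tokens once cumulative probability exceeds top_p; shift
            # right so the first token above the threshold survives.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = \
                sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            for i in range(sorted_indices.size(0)):
                remove = sorted_indices[i][sorted_indices_to_remove[i]]
                logits[i][remove] = filter_value
        return logits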
@@ -118,15 +111,15 @@ def generate_samples_input_from_file(model):
     if args.sample_output_file is None:
         sample_output_file = args.sample_input_file + ".out"
-        print('could not find `sample-output-file`, setting '
-              'it to {}'.formatsample_output_file())
+        print('could not find `sample-output-file`, setting '
+              'it to {}'.format(sample_output_file))
     fname_out = open(sample_output_file, "w+")

     context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
@@ -148,7 +141,7 @@ def generate_samples_input_from_file(model):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
-            for counter, decode_tokens in enumerate(token_stream):
+            for _, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
@@ -176,24 +168,24 @@ def generate_samples_input_from_file(model):
                     fname_out.write("\n\nMegatron-LM:")
                     fname_out.write(trim_decode_tokens)
                     fname_out.write("\n")

             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1


 def generate_samples_interactive(model, print_frequency=24):
     """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

     context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -201,7 +193,7 @@ def generate_samples_interactive(model, print_frequency=24):
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")

                 if "stop" in raw_text:
                     terminate_runs = 1
                 else:
@@ -216,7 +208,7 @@ def generate_samples_interactive(model, print_frequency=24):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])

             for counter, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
@@ -250,20 +241,19 @@ def generate_samples_interactive(model, print_frequency=24):
             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1

             if mpu.get_model_parallel_rank() == 0:
                 input("\nPress any key to continue >>>")


 def generate_samples_unconditional(model):
     """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

     num_samples = args.num_samples
     context_tokens = [[tokenizer.eod] for _ in range(args.batch_size)]
     samples = []
     ctr = 0
     while True:
         start_time = time.time()
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):
 def write_and_generate_samples_unconditional(model):
+    args = get_args()

     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
             f.write(json.dumps(datum) + '\n')


-def pad_batch(batch, tokenizer, args):
-    pad_id = tokenizer.eod
+def pad_batch(batch, pad_id, args):

     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
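Note: the hunk shows only the new pad_batch signature and the start of its loop. A plausible completion, assuming args.seq_length is the padding target (the function name here is illustrative):

    def pad_batch_sketch(batch, pad_id, args):
        # Pad every context (a list of token ids) up to the model's
        # sequence length and remember the original lengths.
        context_lengths = []
        for tokens in batch:
            context_length = len(tokens)
            if context_length < args.seq_length:
                tokens.extend([pad_id] * (args.seq_length - context_length))
            context_lengths.append(context_length)
        return batch, context_lengths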
@@ -310,11 +301,12 @@ def get_token_stream(model, context_tokens):
+    args = get_args()
+    tokenizer = get_tokenizer()
+    pad_id = tokenizer.eod

-    context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer, args)
+    context_tokens, context_lengths = pad_batch(context_tokens,
+                                                tokenizer.eod, args)

     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
@@ -327,12 +319,7 @@ def get_token_stream(model, context_tokens):
                               group=mpu.get_model_parallel_group())

     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor,
                                                      args)
-    counter = 0
-    org_context_length = context_length
-    layer_past = None

     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
@@ -343,21 +330,22 @@ def get_token_stream(model, context_tokens):
 def switch(val1, val2, boolean):

     boolean = boolean.type_as(val1)
     return (1 - boolean) * val1 + boolean * val2


 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
     """XXX"""
+    args = get_args()
+    tokenizer = get_tokenizer()

     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        eos_id = tokenizer.get_command('eos').Id
+        eos_id = tokenizer.eod

         counter = 0
         org_context_length = context_length
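Note: switch(val1, val2, boolean) is an elementwise select: positions where boolean is 1 take val2, the rest keep val1. A quick check of the equivalence with torch.where (values are illustrative):

    import torch

    val1 = torch.tensor([10, 20, 30])   # e.g. previously decoded tokens
    val2 = torch.tensor([1, 2, 3])      # e.g. freshly sampled tokens
    mask = torch.tensor([0, 1, 0])      # 1 where the new value should win

    blended = (1 - mask) * val1 + mask * val2     # formulation in the diff
    assert torch.equal(blended, torch.where(mask.bool(), val2, val1))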
@@ -372,7 +360,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             maxlen = org_context_length + args.out_seq_length

         lengths = torch.ones([batch_size]).long().cuda() * maxlen

         while context_length <= (maxlen):
             if args.recompute:
@@ -404,7 +392,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                    get_key_value=True,
                                    tokentype_ids=types2use,
                                    forward_method_parallel_output=False)
                 logits = logits[:, -1].view(batch_size, -1).contiguous()

             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             done_token = (prev == eos_id).byte() & started.byte()
             just_finished = (done_token & ~is_done).bool()
             lengths[just_finished.view(-1)] = context_length
-            was_done = is_done
             is_done = is_done | done_token
             done = torch.all(is_done)
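Note: the bookkeeping above marks a sample finished the first time it emits the eos token after its prompt has been fully consumed (the `started` mask). A small worked example with a batch of three; all values are illustrative:

    import torch

    eos_id = 2
    prev = torch.tensor([2, 5, 2])            # tokens sampled this step
    started = torch.tensor([1, 1, 0]).byte()  # sample 2 still inside its prompt
    is_done = torch.tensor([0, 0, 0]).byte()

    done_token = (prev == eos_id).byte() & started
    just_finished = (done_token & ~is_done).bool()  # [True, False, False]
    is_done = is_done | done_token                  # only sample 0 is done;
    # sample 2 emitted eos inside its prompt, so it does not count.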
@@ -438,56 +425,59 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 break


 def add_text_generate_args(parser):
-    """Text generate arguments."""
+    """Text generation arguments."""

-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
+    group = parser.add_argument_group(title='text generation')
+
+    group.add_argument("--temperature", type=float, default=1.0,
+                       help='Sampling temperature.')
+    group.add_argument("--greedy", action='store_true', default=False,
+                       help='Use greedy sampling.')
+    group.add_argument("--top_p", type=float, default=0.0,
+                       help='Top p sampling.')
+    group.add_argument("--top_k", type=int, default=0,
+                       help='Top k sampling.')
+    group.add_argument("--out-seq-length", type=int, default=1024,
+                       help='Size of the output generated text.')
     group.add_argument("--sample-input-file", type=str, default=None,
-                       help='get input from file instead of interactive mode, '
-                            'each line is an input')
+                       help='Get input from file instead of interactive mode, '
+                            'each line is an input.')
     group.add_argument("--sample-output-file", type=str, default=None,
-                       help='output file got from --sample-input-file')
+                       help='Output file got from --sample-input-file')
     group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
+                       help='Number of samples to generate unconditionally, '
                             'defaults to 0 and interactive conditional sampling')
     group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
+                       help='Output file when generating unconditionally')
     group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
+                       help='During generation recompute all attention '
                             'instead of using previously computed keys/values.')

     return parser


 def main():
     """Main program."""

     print('Generate Samples')
     initialize_megatron(extra_args_provider=add_text_generate_args,
                         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
     args = get_args()
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

-    #generate samples
+    # Generate samples.
     if args.num_samples == 0:
         args.batch_size = 1
-        assert args.batch_size == 1
         if args.sample_input_file != "":
             generate_samples_input_from_file(model)
         else:
             generate_samples_interactive(model)
     else:
         write_and_generate_samples_unconditional(model)


 if __name__ == "__main__":
     main()
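Note: the new flags compose into an invocation along these lines; the checkpoint path and prompt file are placeholders, and the usual Megatron tokenizer/vocabulary arguments still apply:

    python generate_samples.py \
           --load /path/to/checkpoint \
           --num-samples 0 \
           --temperature 0.9 \
           --top_p 0.9 \
           --out-seq-length 256 \
           --sample-input-file prompts.txt \
           --sample-output-file prompts.out

With --num-samples 0 (the default) and --sample-input-file set, main() reads one prompt per line from the input file; leaving --sample-input-file unset drops into the interactive loop, and a nonzero --num-samples writes unconditional samples to --genfile instead.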
pretrain_gpt2.py  (mode 100755 → 100644)

File mode changed from 100755 to 100644.