OpenDAS / Megatron-LM

Commit f7fe3865, authored Jun 30, 2021 by rprenger
Parent: a9a3ef50

Clean up, removing a lot of code. Works with curl; people might still want a webpage or CLI.
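The commit message notes that the trimmed-down server already works with curl. As a rough illustration only, the route, port, and JSON field names below are assumptions rather than anything this diff specifies; a client request might look like the following (shown with Python's requests, the curl equivalent being a PUT with the same JSON body):

    # Hypothetical client call against the Megatron API server started by
    # run_api_server_8.3B.sh. The URL, port, HTTP verb, and payload keys are
    # assumptions for illustration; check MegatronServer/MegatronGenerate for
    # the actual route and schema.
    import requests

    response = requests.put(
        "http://localhost:5000/generate",          # assumed host, port, route
        json={"sentences": ["Megatron-LM is"],     # assumed field name
              "max_len": 64},                      # assumed field name
        timeout=300,
    )
    print(response.json())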
Showing 4 changed files, with 15 additions and 350 deletions:

    megatron/api_server.py               +2    -14
    megatron/text_generation_utils.py    +11   -316
    run_api_server_8.3B.sh               +1    -4
    tools/run_api_server.py              +1    -16
megatron/api_server.py  (view file @ f7fe3865)

@@ -19,22 +19,10 @@ from flask_restful import Resource, Api
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.text_generation_utils import pad_batch
-from megatron.text_generation_utils import get_token_stream2
+from megatron.text_generation_utils import tokenize_batch, get_token_stream

 GENERATE_NUM = 0

-def tokenize_batch(sentences):
-    args = get_args()
-    tokenizer = get_tokenizer()
-    context_tokens = [tokenizer.tokenize(s) for s in sentences]
-    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)
-    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    context_length_tensor = torch.cuda.LongTensor(context_lengths)
-    return context_tokens_tensor, context_length_tensor
-
 class MegatronGenerate(Resource):
     def __init__(self, model):
         self.model = model

@@ -82,7 +70,7 @@ class MegatronGenerate(Resource):
     @staticmethod
     def do_generate(model, context_length_tensor, context_tokens_tensor, max_len):
-        token_stream = get_token_stream2(model, context_tokens_tensor, context_length_tensor)
+        token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
         for i, decode_tokens in enumerate(token_stream):
             if i == max_len - 1:
                 break
megatron/text_generation_utils.py  (view file @ f7fe3865)

@@ -85,301 +85,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits

-def generate_samples_input_from_file(model):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    # Read the sample file and open the output file.
-    assert args.sample_input_file is not None, \
-        'sample input file is not provided.'
-    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
-        fname = open(args.sample_input_file, "r")
-        all_raw_text = fname.readlines()
-        input_count = len(all_raw_text)
-        input_pos = 0
-        if args.sample_output_file is None:
-            sample_output_file = args.sample_input_file + ".out"
-            print('`sample-output-file` not specified, setting '
-                  'it to {}'.format(sample_output_file))
-        else:
-            sample_output_file = args.sample_output_file
-        fname_out = open(sample_output_file, "w+")
-
-    context_count = 0
-    model.eval()
-    with torch.no_grad():
-        while True:
-            terminate_runs = 0
-            raw_text_len = 0
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                raw_text = all_raw_text[input_pos]
-                input_pos += 1
-                if input_pos == input_count:
-                    raw_text = "stop"
-                raw_text_len = len(raw_text)
-
-                if "stop" in raw_text:
-                    terminate_runs = 1
-                else:
-                    context_tokens = tokenizer.tokenize(raw_text)
-                    context_length = len(context_tokens)
-
-                    if context_length >= (args.seq_length // 2):
-                        print("\nContext length", context_length,
-                              "\nPlease give smaller context (half of the "
-                              "sequence length)!", flush=True)
-                        continue
-            else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = 0
-
-            input_info = [terminate_runs, raw_text_len, context_length]
-            input_info_tensor = torch.cuda.LongTensor(input_info)
-            torch.distributed.all_reduce(input_info_tensor,
-                                         group=mpu.get_model_parallel_group())
-            terminate_runs = input_info_tensor[0].item()
-            raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
-
-            if terminate_runs == 1:
-                return
-
-            # For pipeline parallel we send context tokens to other stages
-            # so they get the lengths correct
-            if mpu.get_tensor_model_parallel_rank() == 0 \
-               and args.pipeline_model_parallel_size > 1:
-                if mpu.is_pipeline_first_stage():
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                else:
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.empty(context_length,
-                                                        dtype=torch.int64,
-                                                        device=torch.device("cuda"))
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
-
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                if mpu.is_pipeline_first_stage():
-                    os.system('clear')
-                    print("\nContext:", raw_text, flush=True)
-
-                    fname_out.write("\nContext:")
-                    fname_out.write(raw_text)
-
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-                    fname_out.write("\n\nMegatron-LM:")
-                    fname_out.write(trim_decode_tokens)
-                    fname_out.write("\n")
-
-            raw_text = None
-            context_count += 1
-
-
-# We added this function to support the tasks evaluation such as squad
-# and drop in the https://github.com/EleutherAI/lm-evaluation-harness
-# codebase. The lm-evaluation-harness code can now call this function
-# similar to their current generate function call used for gpt style models.
-def generate_samples_eval(model, context, max_gen_length, eos_token_id):
-    # Generate samples for lm evaluation
-    # NEED TO THINK ABOUT eos token
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    raw_text_len = len(context)
-    model.eval()
-
-    context_tokens = tokenizer.tokenize(context)
-    args.out_seq_length = max_gen_length + len(context_tokens)
-    args.eos_id = eos_token_id
-
-    with torch.no_grad():
-        token_stream = get_token_stream(model, [context_tokens])
-        for counter, decode_tokens in enumerate(token_stream):
-            if counter == args.out_seq_length:
-                break
-
-    decode_tokens, _ = decode_tokens
-    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-
-    return trim_decode_tokens
-
-
-def generate_samples_interactive(model, print_frequency=24):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    context_count = 0
-    model.eval()
-    with torch.no_grad():
-        while True:
-            terminate_runs = 0
-            raw_text_len = 0
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                raw_text = input("\nContext prompt (stop to exit) >>> ")
-                while not raw_text:
-                    print('Prompt should not be empty!')
-                    raw_text = input("\nContext prompt (stop to exit) >>> ")
-                raw_text_len = len(raw_text)
-
-                if "stop" in raw_text:
-                    terminate_runs = 1
-                else:
-                    context_tokens = tokenizer.tokenize(raw_text)
-                    context_length = len(context_tokens)
-
-                    if context_length >= (args.seq_length // 2):
-                        print("\nContext length", context_length,
-                              "\nPlease give smaller context (half of the "
-                              "sequence length)!", flush=True)
-                        continue
-            else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = 0
-
-            input_info = [terminate_runs, raw_text_len, context_length]
-            input_info_tensor = torch.cuda.LongTensor(input_info)
-            torch.distributed.all_reduce(input_info_tensor,
-                                         group=mpu.get_model_parallel_group())
-            terminate_runs = input_info_tensor[0].item()
-            raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
-
-            if terminate_runs == 1:
-                return
-
-            # For pipeline parallel we send context tokens to other stages
-            # so they get the lengths correct
-            if mpu.get_tensor_model_parallel_rank() == 0 \
-               and args.pipeline_model_parallel_size > 1:
-                if mpu.is_pipeline_first_stage():
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                else:
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.empty(context_length,
-                                                        dtype=torch.int64,
-                                                        device=torch.device("cuda"))
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
-
-            token_stream = get_token_stream(model, [context_tokens])
-
-            for counter, decode_tokens in enumerate(token_stream):
-                if counter % print_frequency != 0 \
-                   or mpu.get_tensor_model_parallel_rank() != 0 \
-                   or not mpu.is_pipeline_first_stage():
-                    continue
-
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-
-                if not isinstance(decode_tokens, list):
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-                input("\nPress Enter to continue >>>")
-
-            raw_text = None
-            context_count += 1
-
-
-def generate_samples_unconditional(model):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    num_samples = args.num_samples
-    context_tokens = [[tokenizer.eod] for _ in range(args.micro_batch_size)]
-    ctr = 0
-    while True:
-        start_time = time.time()
-        for token_stream in get_token_stream(model,
-                                             copy.deepcopy(context_tokens)):
-            pass
-        if mpu.is_pipeline_last_stage() and \
-           mpu.get_tensor_model_parallel_rank() == 0:
-            if ctr % args.log_interval == 0:
-                print('Avg s/batch:',
-                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
-                start_time = time.time()
-            length = len(token_stream)
-            token_batch = token_stream[0].cpu().numpy().tolist()
-            length_batch = token_stream[1].cpu().numpy().tolist()
-            assert len(length_batch) == args.micro_batch_size
-            for tokens, length in zip(token_batch, length_batch):
-                tokens = tokens[1:length - 1]
-                text = tokenizer.detokenize(tokens)
-                is_finished = length < args.seq_length - 1
-                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
-                yield datum
-                ctr += 1
-                if ctr >= num_samples:
-                    break
-        else:
-            for _ in range(args.micro_batch_size):
-                yield None
-                ctr += 1
-                if ctr >= num_samples:
-                    break
-        if ctr >= num_samples:
-            break
-
-
-def generate_and_write_samples_unconditional(model):
-
-    args = get_args()
-    assert args.genfile is not None
-    with open(args.genfile, 'w') as f:
-        for datum in generate_samples_unconditional(model):
-            if mpu.is_pipeline_last_stage() and \
-               mpu.get_tensor_model_parallel_rank() == 0:
-                f.write(json.dumps(datum) + '\n')
-
-
 def pad_batch(batch, pad_id, args):

     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
@@ -388,7 +94,17 @@ def pad_batch(batch, pad_id, args):
         context_lengths.append(context_length)

     return batch, context_lengths

-def get_token_stream2(model, context_tokens_tensor, context_length_tensor):
+def tokenize_batch(sentences):
+    args = get_args()
+    tokenizer = get_tokenizer()
+    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)
+    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    context_length_tensor = torch.cuda.LongTensor(context_lengths)
+    return context_tokens_tensor, context_length_tensor
+
+
+def get_token_stream(model, context_tokens_tensor, context_length_tensor):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
@@ -402,27 +118,6 @@ def get_token_stream2(model, context_tokens_tensor, context_length_tensor):
         else:
             yield None, None

-def get_token_stream(model, context_tokens):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer.eod, args)
-
-    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    context_length_tensor = torch.cuda.LongTensor(context_lengths)
-
-    torch.distributed.broadcast(context_length_tensor,
-                                mpu.get_tensor_model_parallel_src_rank(),
-                                group=mpu.get_tensor_model_parallel_group())
-    torch.distributed.broadcast(context_tokens_tensor,
-                                mpu.get_tensor_model_parallel_src_rank(),
-                                group=mpu.get_tensor_model_parallel_group())
-
-    return get_token_stream2(model, context_tokens_tensor, context_length_tensor)

 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
run_api_server_8.3B.sh  (view file @ f7fe3865)

@@ -25,8 +25,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_api_server.py \
        --temperature 1.0 \
        --vocab-file $VOCAB_FILE \
        --merge-file $MERGE_FILE \
-       --genfile unconditional_samples.json \
-       --num-samples 1 \
        --top_p 0.9 \
-       --seed 42 \
-       --recompute
+       --seed 42
tools/run_api_server.py  (view file @ f7fe3865)

@@ -27,16 +27,13 @@ from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
-from megatron.text_generation_utils import generate_samples_interactive
-from megatron.api_server import MegatronServer
-from megatron.api_server import MegatronGenerate
+from megatron.api_server import MegatronServer, MegatronGenerate
 import torch


 def do_generate(model):
     context_length_tensor, context_tokens_tensor, max_len = MegatronGenerate.receive_generate_info()
     MegatronGenerate.do_generate(model, context_length_tensor, context_tokens_tensor, max_len)


 def model_provider(pre_process=True, post_process=True):
     """Build the model."""

@@ -46,7 +43,6 @@ def model_provider(pre_process=True, post_process=True):
     return model


 def add_text_generate_args(parser):
     """Text generation arguments."""

     group = parser.add_argument_group(title='text generation')
     group.add_argument("--temperature", type=float, default=1.0,

@@ -59,16 +55,6 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
-    group.add_argument("--sample-input-file", type=str, default=None,
-                       help='Get input from file instead of interactive mode, '
-                       'each line is an input.')
-    group.add_argument("--sample-output-file", type=str, default=None,
-                       help='Output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='Number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='Output file when generating unconditionally')
     group.add_argument("--recompute", action='store_true',
                        help='During generation recompute all attention '
                        'instead of using previously computed keys/values.')

@@ -103,6 +89,5 @@ if __name__ == "__main__":
         torch.distributed.broadcast(choice,
                                     mpu.get_tensor_model_parallel_src_rank(),
                                     group=mpu.get_tensor_model_parallel_group())
         print("got: " + str(choice[0].item()))
         if choice[0].item() == 0:
             do_generate(model)
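The tail of tools/run_api_server.py above shows how the ranks coordinate: rank 0 runs the Flask server and broadcasts an integer command (GENERATE_NUM = 0), while every other rank blocks on that broadcast and joins MegatronGenerate.do_generate when the command is 0, so the collective generation call runs on all ranks together. Below is a minimal, self-contained sketch of that pattern, not the Megatron code itself; it uses the gloo backend so it can run on CPUs, and the command codes and helper names are illustrative only.

    # worker_loop_sketch.py -- illustrative only, not Megatron code.
    # Run with:  torchrun --nproc_per_node=2 worker_loop_sketch.py
    import torch
    import torch.distributed as dist

    GENERATE_CMD = 0   # mirrors GENERATE_NUM = 0 in megatron/api_server.py
    SHUTDOWN_CMD = 1   # hypothetical extra code, used here only to end the demo

    def fake_generate(rank: int) -> None:
        # Stand-in for MegatronGenerate.do_generate(...): in the real server every
        # rank must enter generation together because it involves collectives.
        print(f"rank {rank}: running a generate step")

    def main() -> None:
        dist.init_process_group(backend="gloo")   # torchrun supplies rank/world size
        rank = dist.get_rank()
        command = torch.zeros(1, dtype=torch.long)

        if rank == 0:
            # Rank 0 would normally sit inside the Flask server and broadcast a
            # command whenever an HTTP request arrives; here it issues a single
            # generate command and then shuts the demo down.
            for cmd in (GENERATE_CMD, SHUTDOWN_CMD):
                command[0] = cmd
                dist.broadcast(command, src=0)
                if cmd == GENERATE_CMD:
                    fake_generate(rank)
        else:
            while True:
                dist.broadcast(command, src=0)    # block until rank 0 sends a command
                if command[0].item() == GENERATE_CMD:
                    fake_generate(rank)
                else:
                    break

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()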