OpenDAS / Megatron-LM / Commits / 2cdebf4b

Commit 2cdebf4b authored Jun 23, 2021 by rprenger
Working single sentence version
parent 83c4d95a

Showing 3 changed files with 250 additions and 12 deletions (+250 −12)
megatron/api_server.py             +128 −0
megatron/text_generation_utils.py  +14  −12
tools/run_api_server.py            +108 −0
megatron/api_server.py  0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from flask import Flask, request, jsonify
from flask_restful import Resource, Api

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.text_generation_utils import pad_batch
from megatron.text_generation_utils import get_token_stream2

GENERATE_NUM = 0


def tokenize_batch(sentences):
    args = get_args()
    tokenizer = get_tokenizer()
    context_tokens = [tokenizer.tokenize(s) for s in sentences]
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor


class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model

    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    @staticmethod
    def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
        """
        Needs to be synced up with receive_generate_info
        """
        # Send the sizes of the tensors
        input_info = [context_tokens_tensor.size(0),
                      context_tokens_tensor.size(1),
                      max_len]
        input_info_tensor = torch.cuda.LongTensor(input_info)
        torch.distributed.broadcast(input_info_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

        # Now send tensors
        torch.distributed.broadcast(context_length_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        torch.distributed.broadcast(context_tokens_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    @staticmethod
    def receive_generate_info():
        """
        Needs to be synced up with send_generate_info
        """
        input_info_tensor = torch.empty(3, dtype=torch.int64,
                                        device=torch.device("cuda"))
        torch.distributed.broadcast(input_info_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        batch_size = input_info_tensor[0].item()
        seq_len = input_info_tensor[1].item()
        max_len = input_info_tensor[2].item()

        context_length_tensor = torch.empty(batch_size, dtype=torch.int64,
                                            device=torch.device("cuda"))
        context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64,
                                            device=torch.device("cuda"))
        torch.distributed.broadcast(context_length_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        torch.distributed.broadcast(context_tokens_tensor,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        return context_length_tensor, context_tokens_tensor, max_len

    @staticmethod
    def do_generate(model, context_length_tensor, context_tokens_tensor, max_len):
        token_stream = get_token_stream2(model,
                                         context_tokens_tensor,
                                         context_length_tensor)
        for i, decode_tokens in enumerate(token_stream):
            if i == max_len - 1:
                break
            pass
        return decode_tokens

    def put(self):
        sentences = request.get_json()["sentences"]
        max_len = 1024  # TODO (rprenger) this should not be hardcoded
        if "max_len" in request.get_json():
            max_len = request.get_json()["max_len"]

        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
        MegatronGenerate.send_generate_info(context_tokens_tensor,
                                            context_length_tensor,
                                            max_len)  # Send them info
        decode_tokens = MegatronGenerate.do_generate(self.model,
                                                     context_length_tensor,
                                                     context_tokens_tensor,
                                                     max_len)  # Do stuff

        args = get_args()
        tokenizer = get_tokenizer()
        decode_tokens, _ = decode_tokens
        decode_tokens = decode_tokens[0].cpu().numpy().tolist()
        trim_decode_tokens = tokenizer.detokenize(decode_tokens)
        return jsonify({"sentences": [trim_decode_tokens]})


class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__)
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate',
                         resource_class_args=[model])

    def run(self, url):
        self.app.run(url, debug=False)
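
For orientation, here is a minimal client sketch for the PUT /generate resource defined above. It is not part of this commit: the "sentences" list and optional "max_len" field match what put() reads from the request body, and the host/port assume the server started by tools/run_api_server.py is local on Flask's default port 5000 (MegatronServer.run only passes a host).

# Hypothetical client, for illustration only.
import requests

resp = requests.put(
    "http://localhost:5000/generate",
    json={"sentences": ["The quick brown fox"], "max_len": 64},
)
print(resp.json()["sentences"][0])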
megatron/text_generation_utils.py

@@ -387,6 +387,19 @@ def pad_batch(batch, pad_id, args):
         context_lengths.append(context_length)
     return batch, context_lengths
 
+
+def get_token_stream2(model, context_tokens_tensor, context_length_tensor):
+    context_length = context_length_tensor.min().item()
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids)
+    for tokens, lengths in batch_token_iterator:
+        context_length += 1
+        if tokens is not None:
+            yield tokens[:, :context_length], lengths
+        else:
+            yield None, None
 
 def get_token_stream(model, context_tokens):
@@ -406,18 +419,7 @@ def get_token_stream(model, context_tokens):
                                 mpu.get_tensor_model_parallel_src_rank(),
                                 group=mpu.get_tensor_model_parallel_group())
 
-    context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
-                                                 context_length_tensor,
-                                                 attention_mask, position_ids)
-    for tokens, lengths in batch_token_iterator:
-        context_length += 1
-        if tokens is not None:
-            yield tokens[:, :context_length], lengths
-        else:
-            yield None, None
+    return get_token_stream2(model, context_tokens_tensor, context_length_tensor)
 
 def switch(val1, val2, boolean):
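
As a usage sketch of the new generator: get_token_stream2 yields the batch tokens truncated to a context length that grows by one per step, so a consumer can detokenize progressively longer prefixes. The stream_print helper below is hypothetical and not part of this diff; it mirrors how MegatronGenerate.do_generate consumes the stream in megatron/api_server.py above.

# Hypothetical streaming consumer of get_token_stream2 (illustration only).
from megatron import get_tokenizer
from megatron.text_generation_utils import get_token_stream2

def stream_print(model, context_tokens_tensor, context_length_tensor, max_len=64):
    tokenizer = get_tokenizer()
    stream = get_token_stream2(model, context_tokens_tensor, context_length_tensor)
    for i, (tokens, lengths) in enumerate(stream):
        if tokens is None or i == max_len - 1:
            break
        # tokens has shape [batch, growing context length]; show the first sample.
        print(tokenizer.detokenize(tokens[0].cpu().numpy().tolist()))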
tools/run_api_server.py  0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
import socket

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_samples_interactive
from megatron.api_server import MegatronServer
from megatron.api_server import MegatronGenerate
import torch


def do_generate(model):
    context_length_tensor, context_tokens_tensor, max_len = \
        MegatronGenerate.receive_generate_info()
    MegatronGenerate.do_generate(model,
                                 context_length_tensor,
                                 context_tokens_tensor,
                                 max_len)


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)

    return model


def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0")

    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        print("got: " + str(choice[0].item()))
        if choice[0].item() == 0:
            do_generate(model)