OpenDAS / Megatron-LM · Commit 3fe6821a

Adding API server

Authored Aug 27, 2021 by Ryan Prenger; committed by Jared Casper, Aug 27, 2021
Parent: 136d63cb

Showing 7 changed files with 356 additions and 353 deletions (+356 -353)
  examples/run_text_generation_server_345M.sh                      +32   -0
  examples/run_text_generation_server_345M_8_tensor_parallel.sh    +32   -0
  megatron/initialize.py                                            +3   -2
  megatron/text_generation_server.py                               +66   -0
  megatron/text_generation_utils.py                                +104  -351
  tools/run_text_generation_server.py                              +85   -0
  tools/text_generation_cli.py                                     +34   -0
examples/run_text_generation_server_345M.sh  (new file, mode 0 → 100755)
#!/bin/bash
# This example will start serving the 345M model.
DISTRIBUTED_ARGS="--nproc_per_node 1 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

CHECKPOINT=<Path to checkpoint (e.g. /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

pip install flask-restful

python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
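Once the server is running, generation is requested with an HTTP PUT to /generate whose JSON body carries "sentences" and, optionally, "max_len". A minimal client sketch (not part of the commit; it assumes the server is reachable on Flask's default port 5000 on localhost):

import json
import urllib.request

# Illustrative endpoint; adjust host/port to wherever the server script was launched.
URL = "http://localhost:5000/generate"

payload = json.dumps({"sentences": ["The quick brown fox"],
                      "max_len": 32}).encode("utf-8")
request = urllib.request.Request(URL, data=payload,
                                 headers={"Content-Type": "application/json"},
                                 method="PUT")
with urllib.request.urlopen(request) as response:
    print(json.load(response)["sentences"][0])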
examples/run_text_generation_server_345M_8_tensor_parallel.sh  (new file, mode 0 → 100755)
#!/bin/bash
# This example will start serving the 345M model that is partitioned with 8-way tensor model parallelism.
DISTRIBUTED_ARGS="--nproc_per_node 8 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

CHECKPOINT=<Path to checkpoint (e.g. /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

pip install flask-restful

python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 8 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
megatron/initialize.py
@@ -21,6 +21,7 @@ import time
 import numpy as np

 import torch
+from datetime import timedelta

 from megatron import fused_kernels
 from megatron import get_adlr_autoresume

@@ -175,8 +176,8 @@ def _initialize_distributed():
         # Call the init process
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
-            world_size=args.world_size, rank=args.rank)
+            world_size=args.world_size, rank=args.rank,
+            timeout=timedelta(days=7))

         # Set the tensor model-parallel, pipeline model-parallel, and
         # data-parallel communicators.
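The functional change here is the week-long timeout passed to init_process_group. With the API server, every rank other than the coordinator blocks in a broadcast waiting for the next request, and an idle period can easily exceed torch.distributed's default process-group timeout of 30 minutes. A standalone sketch of the pattern (illustrative single-process, gloo-backend setup; not part of the commit):

from datetime import timedelta
import torch

# Megatron passes backend=args.distributed_backend and the real world_size/rank;
# the values below are only for a self-contained, single-process illustration.
torch.distributed.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    world_size=1,
    rank=0,
    timeout=timedelta(days=7))  # keep idle ranks alive between generation requests

print("initialized:", torch.distributed.is_initialized())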
megatron/text_generation_server.py  (new file, mode 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate

GENERATE_NUM = 0

class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model

    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    def put(self):
        args = get_args()
        sentences = request.get_json()["sentences"]
        if len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "max_len" in request.get_json():
            max_len = request.get_json()["max_len"]
            if not isinstance(max_len, int):
                return "max_len must be an integer greater than 0"
            if max_len < 1:
                return "max_len must be an integer greater than 0"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
        resp_sentences = generate(self.model, sentences, max_len)
        return jsonify({"sentences": resp_sentences})

def index():
    return current_app.send_static_file('index.html')

class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__)
        self.app.add_url_rule('/', 'index', index)
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

    def run(self, url):
        self.app.run(url, threaded=False, debug=False)
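As a side note on the flask_restful wiring used above: resource_class_args is how the shared model object reaches the MegatronGenerate constructor. A self-contained sketch of the same pattern with a stand-in model (illustrative names; not part of the commit):

from flask import Flask, request, jsonify
from flask_restful import Resource, Api

class EchoGenerate(Resource):
    def __init__(self, model):
        # Whatever was passed via resource_class_args shows up here.
        self.model = model

    def put(self):
        sentences = request.get_json()["sentences"]
        # A real server would run the model; this stub just echoes.
        return jsonify({"sentences": [s + " ..." for s in sentences]})

app = Flask(__name__)
api = Api(app)
api.add_resource(EchoGenerate, '/generate', resource_class_args=["stand-in model"])

if __name__ == "__main__":
    app.run("0.0.0.0", threaded=False, debug=False)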
megatron/text_generation_utils.py
@@ -40,7 +40,8 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()

     # Move to GPU.
-    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
+    tokens = context_tokens.contiguous().cuda()
     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
@@ -84,301 +85,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits

-def generate_samples_input_from_file(model):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    # Read the sample file and open the output file.
-    assert args.sample_input_file is not None, \
-        'sample input file is not provided.'
-    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
-        fname = open(args.sample_input_file, "r")
-        all_raw_text = fname.readlines()
-        input_count = len(all_raw_text)
-        input_pos = 0
-        if args.sample_output_file is None:
-            sample_output_file = args.sample_input_file + ".out"
-            print('`sample-output-file` not specified, setting '
-                  'it to {}'.format(sample_output_file))
-        else:
-            sample_output_file = args.sample_output_file
-        fname_out = open(sample_output_file, "w+")
-
-    context_count = 0
-    model.eval()
-    with torch.no_grad():
-        while True:
-            terminate_runs = 0
-            raw_text_len = 0
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                raw_text = all_raw_text[input_pos]
-                input_pos += 1
-                if input_pos == input_count:
-                    raw_text = "stop"
-                raw_text_len = len(raw_text)
-
-                if "stop" in raw_text:
-                    terminate_runs = 1
-                else:
-                    context_tokens = tokenizer.tokenize(raw_text)
-                    context_length = len(context_tokens)
-
-                    if context_length >= (args.seq_length // 2):
-                        print("\nContext length", context_length,
-                              "\nPlease give smaller context (half of the "
-                              "sequence length)!", flush=True)
-                        continue
-            else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = 0
-
-            input_info = [terminate_runs, raw_text_len, context_length]
-            input_info_tensor = torch.cuda.LongTensor(input_info)
-            torch.distributed.all_reduce(input_info_tensor,
-                                         group=mpu.get_model_parallel_group())
-            terminate_runs = input_info_tensor[0].item()
-            raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
-
-            if terminate_runs == 1:
-                return
-
-            # For pipeline parallel we send context tokens to other stages
-            # so they get the lengths correct
-            if mpu.get_tensor_model_parallel_rank() == 0 \
-               and args.pipeline_model_parallel_size > 1:
-                if mpu.is_pipeline_first_stage():
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                else:
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.empty(context_length,
-                                                        dtype=torch.int64,
-                                                        device=torch.device("cuda"))
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
-
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                if mpu.is_pipeline_first_stage():
-                    os.system('clear')
-                    print("\nContext:", raw_text, flush=True)
-
-                    fname_out.write("\nContext:")
-                    fname_out.write(raw_text)
-
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[raw_text_len:]
-                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-                    fname_out.write("\n\nMegatron-LM:")
-                    fname_out.write(trim_decode_tokens)
-                    fname_out.write("\n")
-
-            raw_text = None
-            context_count += 1
-
-
-# We added this function to support the tasks evaluation such as squad
-# and drop in the https://github.com/EleutherAI/lm-evaluation-harness
-# codebase. The lm-evaluation-harness code can now call this function
-# similar to their current generate function call used for gpt style models.
-def generate_samples_eval(model, context, max_gen_length, eos_token_id):
-    # Generate samples for lm evaluation
-    # NEED TO THINK ABOUT eos token
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    raw_text_len = len(context)
-    model.eval()
-
-    context_tokens = tokenizer.tokenize(context)
-    args.out_seq_length = max_gen_length + len(context_tokens)
-    args.eos_id = eos_token_id
-
-    with torch.no_grad():
-        token_stream = get_token_stream(model, [context_tokens])
-        for counter, decode_tokens in enumerate(token_stream):
-            if counter == args.out_seq_length:
-                break
-
-    decode_tokens, _ = decode_tokens
-    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-
-    return trim_decode_tokens
-
-
-def generate_samples_interactive(model, print_frequency=24):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    context_count = 0
-    model.eval()
-    with torch.no_grad():
-        while True:
-            terminate_runs = 0
-            raw_text_len = 0
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                raw_text = input("\nContext prompt (stop to exit) >>> ")
-                while not raw_text:
-                    print('Prompt should not be empty!')
-                    raw_text = input("\nContext prompt (stop to exit) >>> ")
-                raw_text_len = len(raw_text)
-
-                if "stop" in raw_text:
-                    terminate_runs = 1
-                else:
-                    context_tokens = tokenizer.tokenize(raw_text)
-                    context_length = len(context_tokens)
-
-                    if context_length >= (args.seq_length // 2):
-                        print("\nContext length", context_length,
-                              "\nPlease give smaller context (half of the "
-                              "sequence length)!", flush=True)
-                        continue
-            else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                context_length = 0
-
-            input_info = [terminate_runs, raw_text_len, context_length]
-            input_info_tensor = torch.cuda.LongTensor(input_info)
-            torch.distributed.all_reduce(input_info_tensor,
-                                         group=mpu.get_model_parallel_group())
-            terminate_runs = input_info_tensor[0].item()
-            raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
-
-            if terminate_runs == 1:
-                return
-
-            # For pipeline parallel we send context tokens to other stages
-            # so they get the lengths correct
-            if mpu.get_tensor_model_parallel_rank() == 0 \
-               and args.pipeline_model_parallel_size > 1:
-                if mpu.is_pipeline_first_stage():
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                else:
-                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_pipeline_model_parallel_group()
-                    context_tokens_tensor = torch.empty(context_length,
-                                                        dtype=torch.int64,
-                                                        device=torch.device("cuda"))
-                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()
-
-            token_stream = get_token_stream(model, [context_tokens])
-
-            for counter, decode_tokens in enumerate(token_stream):
-                if counter % print_frequency != 0 \
-                   or mpu.get_tensor_model_parallel_rank() != 0 \
-                   or not mpu.is_pipeline_first_stage():
-                    continue
-
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-
-                decode_tokens, _ = decode_tokens
-                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[raw_text_len:]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-            if mpu.is_pipeline_first_stage() \
-               and mpu.get_tensor_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nContext:", raw_text, flush=True)
-
-                if not isinstance(decode_tokens, list):
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                trim_decode_tokens = tokenizer.detokenize(
-                    decode_tokens)[raw_text_len:]
-                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
-
-                input("\nPress Enter to continue >>>")
-
-            raw_text = None
-            context_count += 1
-
-
-def generate_samples_unconditional(model):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    num_samples = args.num_samples
-    context_tokens = [[tokenizer.eod]
-                      for _ in range(args.micro_batch_size)]
-    ctr = 0
-    while True:
-        start_time = time.time()
-        for token_stream in get_token_stream(model,
-                                             copy.deepcopy(context_tokens)):
-            pass
-        if mpu.is_pipeline_last_stage() and \
-           mpu.get_tensor_model_parallel_rank() == 0:
-            if ctr % args.log_interval == 0:
-                print('Avg s/batch:',
-                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
-                start_time = time.time()
-            length = len(token_stream)
-            token_batch = token_stream[0].cpu().numpy().tolist()
-            length_batch = token_stream[1].cpu().numpy().tolist()
-            assert len(length_batch) == args.micro_batch_size
-            for tokens, length in zip(token_batch, length_batch):
-                tokens = tokens[1:length - 1]
-                text = tokenizer.detokenize(tokens)
-                is_finished = length < args.seq_length - 1
-                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
-                yield datum
-                ctr += 1
-                if ctr >= num_samples:
-                    break
-        else:
-            for _ in range(args.micro_batch_size):
-                yield None
-                ctr += 1
-                if ctr >= num_samples:
-                    break
-        if ctr >= num_samples:
-            break
-
-
-def generate_and_write_samples_unconditional(model):
-
-    args = get_args()
-    assert args.genfile is not None
-    with open(args.genfile, 'w') as f:
-        for datum in generate_samples_unconditional(model):
-            if mpu.is_pipeline_last_stage() and \
-               mpu.get_tensor_model_parallel_rank() == 0:
-                f.write(json.dumps(datum) + '\n')
-
-
 def pad_batch(batch, pad_id, args):

     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
@@ -387,41 +94,94 @@ def pad_batch(batch, pad_id, args):
         context_lengths.append(context_length)
     return batch, context_lengths

-def get_token_stream(model, context_tokens):
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer.eod, args)
-
-    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    context_length_tensor = torch.cuda.LongTensor(context_lengths)
-
-    torch.distributed.broadcast(context_length_tensor,
-                                mpu.get_tensor_model_parallel_src_rank(),
-                                group=mpu.get_tensor_model_parallel_group())
-    torch.distributed.broadcast(context_tokens_tensor,
-                                mpu.get_tensor_model_parallel_src_rank(),
-                                group=mpu.get_tensor_model_parallel_group())
-
-    context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
-                                                 context_length_tensor,
-                                                 attention_mask, position_ids)
-    for tokens, lengths in batch_token_iterator:
-        context_length += 1
-        if tokens is not None:
-            yield tokens[:, :context_length], lengths
-        else:
-            yield None, None
+def tokenize_batch(sentences):
+    args = get_args()
+    tokenizer = get_tokenizer()
+    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+    context_tokens, context_lengths = pad_batch(context_tokens,
+                                                tokenizer.eod, args)
+    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    context_length_tensor = torch.cuda.LongTensor(context_lengths)
+    return context_tokens_tensor, context_length_tensor
+
+def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
+    """
+    Needs to be synced up with receive_generate_info
+    """
+    # Send the sizes of the tensors
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
+    input_info_tensor = torch.cuda.LongTensor(input_info)
+    torch.distributed.broadcast(input_info_tensor, 0)
+
+    # Send variables to all ranks
+    torch.distributed.broadcast(context_length_tensor, 0)
+    torch.distributed.broadcast(context_tokens_tensor, 0)
+
+def receive_generate_info():
+    """
+    Needs to be synced up with send_generate_info
+    """
+    input_info_tensor = torch.empty(3, dtype=torch.int64,
+                                    device=torch.cuda.current_device())
+    torch.distributed.broadcast(input_info_tensor, 0)
+    batch_size = input_info_tensor[0].item()
+    seq_len = input_info_tensor[1].item()
+    max_len = input_info_tensor[2].item()
+
+    context_length_tensor = torch.empty(batch_size, dtype=torch.int64,
+                                        device=torch.cuda.current_device())
+    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64,
+                                        device=torch.cuda.current_device())
+
+    # Send variables to all ranks
+    torch.distributed.broadcast(context_length_tensor, 0)
+    torch.distributed.broadcast(context_tokens_tensor, 0)
+
+    return context_length_tensor, context_tokens_tensor, max_len
+
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+    context_length = context_length_tensor.min().item()
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids,
+                                                 max_len)
+    for tokens, lengths in batch_token_iterator:
+        context_length += 1
+
+    if tokens is not None:
+        return tokens[:, :context_length]
+
+def generate(model, sentences=None, max_len=0):
+    model.eval()
+    if torch.distributed.get_rank() == 0:
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
+        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
+    else:
+        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
+
+    decode_tokens = synced_generate(model, context_tokens_tensor,
+                                    context_length_tensor, max_len)
+
+    if torch.distributed.get_rank() == 0:
+        args = get_args()
+        tokenizer = get_tokenizer()
+        resp_sentences = []
+        for i in range(decode_tokens.size(0)):
+            decode_token = decode_tokens[i, :].cpu().numpy().tolist()
+            resp_sentences.append(tokenizer.detokenize(decode_token))
+        return resp_sentences
+
+def generate_samples_eval(model, context, max_gen_length, eos_token_id):
+    """
+    This function is here to provide an a matching API for a legacy task
+    This implementation hasn't been tested yet to make sure it matches
+    """
+    assert False, "Implementation untested"
+    args = get_args()
+    args.eos_id = eos_token_id
+    raw_text_len = len(context)
+    resp_sentences = generate(model, [context], max_gen_length)
+    return resp_sentences[0][raw_text_len:]

 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
     return (1 - boolean) * val1 + boolean * val2
@@ -435,6 +195,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     args = get_args()
     orig_seq_length = args.seq_length
     args.seq_length = tokens.shape[1]
+    args.micro_batch_size = tokens.shape[0]

     input_tensor = recv_forward()
@@ -462,7 +223,6 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):

     args = get_args()
     tokenizer = get_tokenizer()
@@ -486,22 +246,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         tokens = context_tokens
         if maxlen is None:
             maxlen = args.seq_length - 1
-            if maxlen > (org_context_length + args.out_seq_length):
-                maxlen = org_context_length + args.out_seq_length
+        maxlen = maxlen + org_context_length
+        if maxlen > (org_context_length + args.out_seq_length):
+            maxlen = org_context_length + args.out_seq_length

         lengths = torch.ones([batch_size]).long().cuda() * maxlen

-        while context_length <= (maxlen):
-            if args.recompute:
-                output = forward_step(model, tokens,
-                                      position_ids,
-                                      attention_mask,
-                                      tokentype_ids=type_ids,
-                                      forward_method_parallel_output=False)
-                if mpu.is_pipeline_last_stage():
-                    assert output is not None
-                    logits = output[:, context_length - 1, :]
-            else:
+        while context_length < maxlen:
                 types2use = None
                 if counter == 0:
                     tokens2use = tokens[:, :context_length]
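The send_generate_info / receive_generate_info pair above follows a size-then-payload broadcast handshake: rank 0 first broadcasts the tensor shapes so the other ranks can allocate matching buffers, then broadcasts the tensors themselves. A standalone sketch of that handshake on the gloo backend so it runs without GPUs (illustrative file name and launch command; not part of the commit):

# Launch with, e.g.:  python -m torch.distributed.launch --nproc_per_node 2 handshake_sketch.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")  # reads MASTER_ADDR/PORT, RANK, WORLD_SIZE from env
    rank = dist.get_rank()

    if rank == 0:
        payload = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int64)
        sizes = torch.tensor([payload.size(0), payload.size(1)], dtype=torch.int64)
        dist.broadcast(sizes, 0)    # 1) announce the shape
        dist.broadcast(payload, 0)  # 2) send the data itself
    else:
        sizes = torch.empty(2, dtype=torch.int64)
        dist.broadcast(sizes, 0)    # 1) learn the shape
        payload = torch.empty(sizes[0].item(), sizes[1].item(), dtype=torch.int64)
        dist.broadcast(payload, 0)  # 2) receive into a matching buffer

    print("rank", rank, "received", payload.tolist())

if __name__ == "__main__":
    main()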
tools/run_text_generation_server.py  (new file, mode 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation_utils import generate
import torch

def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)

    return model

def add_text_generate_args(parser):
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    return parser

if __name__ == "__main__":
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint
    model = get_model(model_provider)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0")

    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        if choice[0].item() == 0:
            generate(model)
tools/text_generation_cli.py  (new file, mode 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import urllib2

class PutRequest(urllib2.Request):
    '''class to handling putting with urllib2'''

    def get_method(self, *args, **kwargs):
        return 'PUT'

if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = raw_input("Enter prompt: ")
        max_len = int(input("Enter number tokens output: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len})
        req = PutRequest(url, data, {'Content-Type': 'application/json'})
        response = urllib2.urlopen(req)
        resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["sentences"][0])
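The CLI above is written against Python 2 APIs (urllib2, raw_input). A Python 3 equivalent, as a sketch using urllib.request (not part of the commit):

import json
import sys
import urllib.request

if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = input("Enter prompt: ")
        max_len = int(input("Enter number tokens output: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len}).encode("utf-8")
        req = urllib.request.Request(url, data=data,
                                     headers={"Content-Type": "application/json"},
                                     method="PUT")
        with urllib.request.urlopen(req) as response:
            resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["sentences"][0])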