OpenDAS / Megatron-LM · Commit 3fe6821a

Adding API server

Authored Aug 27, 2021 by Ryan Prenger; committed by Jared Casper, Aug 27, 2021.
Parent: 136d63cb
Showing 7 changed files with 356 additions and 353 deletions:
examples/run_text_generation_server_345M.sh                     +32   −0
examples/run_text_generation_server_345M_8_tensor_parallel.sh   +32   −0
megatron/initialize.py                                           +3   −2
megatron/text_generation_server.py                              +66   −0
megatron/text_generation_utils.py                              +104 −351
tools/run_text_generation_server.py                             +85   −0
tools/text_generation_cli.py                                    +34   −0
examples/run_text_generation_server_345M.sh (new file, mode 100755)

#!/bin/bash
# This example will start serving the 345M model.
DISTRIBUTED_ARGS="--nproc_per_node 1 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

CHECKPOINT=<Path to checkpoint (e.g. /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

pip install flask-restful

python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
examples/run_text_generation_server_345M_8_tensor_parallel.sh (new file, mode 100755)

#!/bin/bash
# This example will start serving the 345M model that is partitioned 8-way tensor parallel.
DISTRIBUTED_ARGS="--nproc_per_node 8 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

CHECKPOINT=<Path to checkpoint (e.g. /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

pip install flask-restful

python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 8 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
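The two launch scripts differ only in --nproc_per_node (1 vs. 8) and --tensor-model-parallel-size (1 vs. 8). The tensor-parallel size passed here has to match how the checkpoint being loaded was partitioned, since Megatron saves one checkpoint shard per tensor-parallel rank.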
megatron/initialize.py

@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch
+from datetime import timedelta

 from megatron import fused_kernels
 from megatron import get_adlr_autoresume
@@ -175,8 +176,8 @@ def _initialize_distributed():
         # Call the init process
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
-            world_size=args.world_size,
-            rank=args.rank)
+            world_size=args.world_size, rank=args.rank,
+            timeout=timedelta(days=7))
         # Set the tensor model-parallel, pipeline model-parallel, and
         # data-parallel communicators.
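The week-long timeout added here matters for serving: while the first rank is busy in Flask handling HTTP traffic, the remaining ranks sit blocked in torch.distributed.broadcast waiting for the next control message (see tools/run_text_generation_server.py below), and PyTorch's default process-group timeout of 30 minutes would otherwise tear the group down during any long idle stretch.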
megatron/text_generation_server.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate

GENERATE_NUM = 0

class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model

    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    def put(self):
        args = get_args()
        sentences = request.get_json()["sentences"]
        if len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        max_len = 64  # Choosing hopefully sane default. Full sequence is slow
        if "max_len" in request.get_json():
            max_len = request.get_json()["max_len"]
            if not isinstance(max_len, int):
                return "max_len must be an integer greater than 0"
            if max_len < 1:
                return "max_len must be an integer greater than 0"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
        resp_sentences = generate(self.model, sentences, max_len)
        return jsonify({"sentences": resp_sentences})

def index():
    return current_app.send_static_file('index.html')

class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__)
        self.app.add_url_rule('/', 'index', index)
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate',
                         resource_class_args=[model])

    def run(self, url):
        self.app.run(url, threaded=False, debug=False)
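In short, the API is a JSON PUT to /generate: the request body carries "sentences" (a list of prompts, at most 128) and optionally "max_len" (a positive integer number of tokens to generate, defaulting to 64), and the response is a JSON object whose "sentences" field holds the completions. Requests with more than 128 sentences are rejected with HTTP 400.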
megatron/text_generation_utils.py (+104 −351)

This diff is collapsed in the page view and is not shown here.
tools/run_text_generation_server.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation_utils import generate
import torch

def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)
    return model

def add_text_generate_args(parser):
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    return parser

if __name__ == "__main__":
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0")

    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        if choice[0].item() == 0:
            generate(model)
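Note the division of labor at the bottom of the script: only the rank that is both in the first pipeline stage and tensor-model-parallel rank 0 starts the Flask server, while every other rank loops forever in torch.distributed.broadcast waiting for a control value. When a request arrives, MegatronGenerate.send_do_generate() broadcasts GENERATE_NUM (0), which releases the waiting ranks into generate() so that all ranks execute the model's forward passes in lockstep.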
tools/text_generation_cli.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
import urllib2

class PutRequest(urllib2.Request):
    '''Request subclass for issuing PUT requests with urllib2.'''
    def get_method(self, *args, **kwargs):
        return 'PUT'

if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = raw_input("Enter prompt: ")
        # raw_input, not input(): Python 2's input() would eval() the text.
        max_len = int(raw_input("Enter number tokens output: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len})
        req = PutRequest(url, data, {'Content-Type': 'application/json'})
        response = urllib2.urlopen(req)
        resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["sentences"][0])
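The CLI above targets Python 2 (urllib2, raw_input). For reference, a minimal Python 3 equivalent using only the standard library could look like the following sketch; it is not part of the commit, and the URL argument is expected in the same form (e.g. http://localhost:5000/generate, assuming Flask's default port).

import json
import sys
import urllib.request

if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = input("Enter prompt: ")
        max_len = int(input("Enter number tokens output: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len})
        # urllib.request.Request accepts method= directly, so no Request
        # subclass is needed to issue a PUT.
        req = urllib.request.Request(url, data.encode("utf-8"),
                                     {"Content-Type": "application/json"},
                                     method="PUT")
        with urllib.request.urlopen(req) as response:
            resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["sentences"][0])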