OpenDAS / Megatron-LM · Commit 636da044

Authored Mar 09, 2023 by Jared Casper

Merge branch 'maanug/gpt-text-gen' into 'main'

Fix GPT text generation

See merge request ADLR/megatron-lm!528

Parents: f5cf2e42 3c760180
Showing 3 changed files with 21 additions and 17 deletions:

README.md (+2, -2)
examples/run_text_generation_server_345M.sh (+3, -1)
tools/text_generation_cli.py (+16, -14)
README.md
@@ -338,7 +338,7 @@ python tools/checkpoint_util.py \
        --load-dir checkpoints/gpt3_tp4_pp4 \
        --save-dir checkpoints/gpt3_tp2_pp2 \
        --target-tensor-parallel-size 2 \
-       --target-pipeline-paralle-size 2
+       --target-pipeline-parallel-size 2
</pre>
@@ -351,7 +351,7 @@ We have included a simple REST server to use for text generation in `tools/run_t
Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.

<pre>
-tools/text_generation_cli.py localhost
+tools/text_generation_cli.py localhost:5000
</pre>

You can also use CURL or any other tools to query the server directly:
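For reference, the same request the CLI sends can be issued directly. Below is a minimal curl sketch, assuming the server started by `examples/run_text_generation_server_345M.sh` is listening on port 5000; the port and the prompt text are assumptions, while the `/api` path and the JSON fields match the CLI code in this commit:

<pre>
# Sketch only: PUT the JSON payload the CLI builds ("prompts" plus
# "tokens_to_generate") to the /api endpoint; port 5000 is an assumption.
curl 'http://localhost:5000/api' \
     -X PUT \
     -H 'Content-Type: application/json' \
     -d '{"prompts": ["Hello, my name is"], "tokens_to_generate": 32}'
</pre>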
examples/run_text_generation_server_345M.sh
@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

export CUDA_DEVICE_MAX_CONNECTIONS=1

+pip install flask-restful
+
-python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
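The `$DISTRIBUTED_ARGS` variable is defined in the part of the script elided above. A minimal single-node sketch of what it might contain when launching with `torchrun` follows; every value here is an assumption, not taken from this commit:

<pre>
# Sketch only: typical single-node torchrun launcher arguments; the
# variable name comes from the script, the values are assumptions.
GPUS_PER_NODE=1
MASTER_ADDR=localhost
MASTER_PORT=6000
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes 1 --node_rank 0 \
                  --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
</pre>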
tools/text_generation_cli.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import sys
-import urllib2
-class PutRequest(urllib2.Request):
-    '''class to handling putting with urllib2'''
-
-    def get_method(self, *args, **kwargs):
-        return 'PUT'
+import requests
+

if __name__ == "__main__":
    url = sys.argv[1]
+    url = 'http://' + url + '/api'
+    headers = {'Content-Type': 'application/json'}
+
    while True:
-        sentence = raw_input("Enter prompt: ")
-        tokens_to_generate = int(input("Enter number of tokens to generate: "))
-        data = json.dumps({"prompts": [sentence], "tokens_to_generate": tokens_to_generate})
-        req = PutRequest(url, data, {'Content-Type': 'application/json'})
-        response = urllib2.urlopen(req)
-        resp_sentences = json.load(response)
-        print("Megatron Response: ")
-        print(resp_sentences["text"][0])
+        sentence = input("Enter prompt: ")
+        tokens_to_generate = int(eval(input("Enter number of tokens to generate: ")))
+
+        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
+        response = requests.put(url, data=json.dumps(data), headers=headers)
+
+        if response.status_code != 200:
+            print(f"Error {response.status_code}: {response.json()['message']}")
+        else:
+            print("Megatron Response: ")
+            print(response.json()['text'][0])
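Putting the three changed files together, a hedged usage sketch (the port is an assumption carried over from the README fix in this commit):

<pre>
# Sketch only: start the REST server, then query it with the updated CLI.
bash examples/run_text_generation_server_345M.sh
python tools/text_generation_cli.py localhost:5000
</pre>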