wuxk1 / Megatron-LM / Commits

Commit 3c760180
Authored Mar 09, 2023 by Maanu Grover; committed by Jared Casper, Mar 09, 2023

Fix GPT text generation

Parent: ef59b687
Changes: 3 changed files, with 21 additions and 17 deletions (+21 −17)
README.md                                     +2 −2
examples/run_text_generation_server_345M.sh   +3 −1
tools/text_generation_cli.py                  +16 −14
README.md

@@ -417,7 +417,7 @@ python tools/checkpoint_util.py \
         --load-dir checkpoints/gpt3_tp4_pp4 \
         --save-dir checkpoints/gpt3_tp2_pp2 \
         --target-tensor-parallel-size 2 \
-        --target-pipeline-paralle-size 2
+        --target-pipeline-parallel-size 2
 </pre>

@@ -430,7 +430,7 @@ We have included a simple REST server to use for text generation in `tools/run_t
 Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.
 <pre>
-tools/text_generation_cli.py localhost
+tools/text_generation_cli.py localhost:5000
 </pre>
 You can also use CURL or any other tools to query the server directly:
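The request any such client needs to send can be sketched offline. The endpoint path, payload shape, and example host:port below are taken from the files in this commit; no server is contacted, and no names beyond those are assumed:

```python
import json

# Endpoint and payload shape as used by tools/text_generation_cli.py in this
# commit; "localhost:5000" is the example host:port from the README.
host = "localhost:5000"
url = "http://" + host + "/api"
payload = {"prompts": ["Hello"], "tokens_to_generate": 8}
body = json.dumps(payload)

print(url)   # the URL a client would PUT to
print(body)  # the JSON body the server expects
```

A client such as curl or requests would send `body` in a PUT to `url` with a `Content-Type: application/json` header, matching what the bundled CLI does.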
examples/run_text_generation_server_345M.sh

@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)>
 VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
 MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
 pip install flask-restful

-python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
        --tensor-model-parallel-size 1 \
        --pipeline-model-parallel-size 1 \
        --num-layers 24 \
tools/text_generation_cli.py

 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import sys
-import urllib2
 import json
+import sys
+import requests

-class PutRequest(urllib2.Request):
-    '''class to handling putting with urllib2'''
-    def get_method(self, *args, **kwargs):
-        return 'PUT'

 if __name__ == "__main__":
     url = sys.argv[1]
+    url = 'http://' + url + '/api'
+    headers = {'Content-Type': 'application/json'}
     while True:
-        sentence = raw_input("Enter prompt: ")
-        tokens_to_generate = int(eval(input("Enter number of tokens to generate: ")))
-        data = json.dumps({"prompts": [sentence], "tokens_to_generate": tokens_to_generate})
-        req = PutRequest(url, data, {'Content-Type': 'application/json'})
-        response = urllib2.urlopen(req)
-        resp_sentences = json.load(response)
-        print("Megatron Response: ")
-        print(resp_sentences["text"][0])
+        sentence = input("Enter prompt: ")
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
+        response = requests.put(url, data=json.dumps(data), headers=headers)
+        if response.status_code != 200:
+            print(f"Error {response.status_code}: {response.json()['message']}")
+        else:
+            print("Megatron Response: ")
+            print(response.json()['text'][0])
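Besides porting the client from Python 2's urllib2 to requests, the commit adds a status-code check before printing the response. That branch logic can be exercised without a live server using a stand-in response object; `FakeResponse` and `render` below are hypothetical names for this sketch, not part of the commit:

```python
class FakeResponse:
    """Stand-in for requests.Response: just enough for the CLI's branches."""
    def __init__(self, status_code, payload):
        self.status_code = status_code
        self._payload = payload

    def json(self):
        return self._payload

def render(response):
    # Same branch structure as the updated loop body in text_generation_cli.py:
    # non-200 responses report the server's 'message', 200s print the text.
    if response.status_code != 200:
        return f"Error {response.status_code}: {response.json()['message']}"
    return response.json()['text'][0]

print(render(FakeResponse(200, {"text": ["hello world"]})))   # hello world
print(render(FakeResponse(400, {"message": "bad request"})))  # Error 400: bad request
```

The old urllib2 version would have raised an HTTPError on a non-200 reply; the new code instead surfaces the server's JSON error message.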