Commit 636da044 authored by Jared Casper

Merge branch 'maanug/gpt-text-gen' into 'main'

Fix GPT text generation

See merge request ADLR/megatron-lm!528
parents f5cf2e42 3c760180
@@ -338,7 +338,7 @@ python tools/checkpoint_util.py \
 --load-dir checkpoints/gpt3_tp4_pp4 \
 --save-dir checkpoints/gpt3_tp2_pp2 \
 --target-tensor-parallel-size 2 \
---target-pipeline-paralle-size 2
+--target-pipeline-parallel-size 2
 </pre>
@@ -351,7 +351,7 @@ We have included a simple REST server to use for text generation in `tools/run_t
 Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.
 <pre>
-tools/text_generation_cli.py localhost
+tools/text_generation_cli.py localhost:5000
 </pre>
 You can also use CURL or any other tools to query the server directly:
......
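(The README's curl example is collapsed in this diff.) As a rough sketch of what a direct query looks like after this change, the request below mirrors the PUT call made by the rewritten `tools/text_generation_cli.py`: the host and port (`localhost:5000`), the `/api` route, and the JSON field names are taken from the diff, while the prompt and token count are placeholder values.
<pre>
# Sketch only: query the text generation server directly, mirroring the
# PUT request the updated tools/text_generation_cli.py sends.
# "localhost:5000" and "/api" come from the diff above; the prompt text
# and tokens_to_generate value here are illustrative.
import json
import requests

url = "http://localhost:5000/api"
headers = {"Content-Type": "application/json"}
payload = {"prompts": ["Hello, my name is"], "tokens_to_generate": 16}

response = requests.put(url, data=json.dumps(payload), headers=headers)
if response.status_code != 200:
    # Error responses carry a "message" field, per the CLI code below.
    print(f"Error {response.status_code}: {response.json()['message']}")
else:
    # Successful responses carry the generated text under "text".
    print(response.json()["text"][0])
</pre>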
@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)>
 VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
 MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
 export CUDA_DEVICE_MAX_CONNECTIONS=1
+pip install flask-restful
-python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
 --tensor-model-parallel-size 1 \
 --pipeline-model-parallel-size 1 \
 --num-layers 24 \
......
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import json
 import sys
-import urllib2
-class PutRequest(urllib2.Request):
-    '''class to handling putting with urllib2'''
+import json
+import requests
-    def get_method(self, *args, **kwargs):
-        return 'PUT'
 if __name__ == "__main__":
     url = sys.argv[1]
+    url = 'http://' + url + '/api'
+    headers = {'Content-Type': 'application/json'}
     while True:
-        sentence = raw_input("Enter prompt: ")
-        tokens_to_generate = int(input("Enter number of tokens to generate: "))
-        data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
-        req = PutRequest(url, data, {'Content-Type': 'application/json'})
-        response = urllib2.urlopen(req)
-        resp_sentences = json.load(response)
-        print("Megatron Response: ")
-        print(resp_sentences["text"][0])
+        sentence = input("Enter prompt: ")
+        tokens_to_generate = int(eval(input("Enter number of tokens to generate: ")))
+        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
+        response = requests.put(url, data=json.dumps(data), headers=headers)
+        if response.status_code != 200:
+            print(f"Error {response.status_code}: {response.json()['message']}")
+        else:
+            print("Megatron Response: ")
+            print(response.json()['text'][0])
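As read off the rewritten CLI above, the exchange it assumes is a PUT of a small JSON document followed by a JSON reply. The sketch below restates those shapes in plain Python with made-up values; only the key names (`prompts`, `tokens_to_generate`, `text`, `message`) come from the diff.
<pre>
# Request/response shapes inferred from the new CLI code; values are illustrative.

# Body sent via PUT to http://<host>/api
request_body = {
    "prompts": ["an example prompt"],  # list of prompt strings
    "tokens_to_generate": 32,          # number of tokens to generate
}

# Successful (200) reply: generated text per prompt under "text"
ok_response = {
    "text": ["an example prompt followed by generated tokens..."],
}

# Error reply: the CLI prints the "message" field
error_response = {
    "message": "an error description from the server",
}
</pre>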