Commit 3c760180 authored by Maanu Grover's avatar Maanu Grover Committed by Jared Casper
Browse files

Fix GPT text generation

parent ef59b687
...@@ -417,7 +417,7 @@ python tools/checkpoint_util.py \ ...@@ -417,7 +417,7 @@ python tools/checkpoint_util.py \
--load-dir checkpoints/gpt3_tp4_pp4 \ --load-dir checkpoints/gpt3_tp4_pp4 \
--save-dir checkpoints/gpt3_tp2_pp2 \ --save-dir checkpoints/gpt3_tp2_pp2 \
--target-tensor-parallel-size 2 \ --target-tensor-parallel-size 2 \
--target-pipeline-paralle-size 2 --target-pipeline-parallel-size 2
</pre> </pre>
...@@ -430,7 +430,7 @@ We have included a simple REST server to use for text generation in `tools/run_t ...@@ -430,7 +430,7 @@ We have included a simple REST server to use for text generation in `tools/run_t
Once the server is running you can use `tools/text_generation_cli.py` to query it; it takes one argument: the host the server is running on. Once the server is running you can use `tools/text_generation_cli.py` to query it; it takes one argument: the host the server is running on.
<pre> <pre>
tools/text_generation_cli.py localhost tools/text_generation_cli.py localhost:5000
</pre> </pre>
You can also use CURL or any other tools to query the server directly: You can also use CURL or any other tools to query the server directly:
......
...@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)> ...@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)>
VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)> VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)> MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
export CUDA_DEVICE_MAX_CONNECTIONS=1
pip install flask-restful pip install flask-restful
python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
--tensor-model-parallel-size 1 \ --tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \ --pipeline-model-parallel-size 1 \
--num-layers 24 \ --num-layers 24 \
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Command-line client for the Megatron text generation REST server.

Takes a single argument — the host (and port) the server is running on,
e.g. ``localhost:5000`` — then loops forever reading a prompt and a token
count from stdin, PUTs them as JSON to the server's ``/api`` endpoint,
and prints the generated text (or the server's error message).
"""

import sys
import json

import requests

if __name__ == "__main__":
    # sys.argv[1] is the server address, e.g. "localhost:5000".
    url = 'http://' + sys.argv[1] + '/api'
    headers = {'Content-Type': 'application/json'}
    while True:
        sentence = input("Enter prompt: ")
        # int(input(...)) directly: the previous int(eval(input(...))) was a
        # 2to3 conversion artifact and executed arbitrary user input.
        tokens_to_generate = int(input("Enter number of tokens to generate: "))
        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
        response = requests.put(url, data=json.dumps(data), headers=headers)
        if response.status_code != 200:
            # Server reports failures as JSON with a 'message' field.
            print(f"Error {response.status_code}: {response.json()['message']}")
        else:
            print("Megatron Response: ")
            print(response.json()['text'][0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment