text_generation_cli.py 767 Bytes
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Ryan Prenger's avatar
Ryan Prenger committed
2
import sys
Maanu Grover's avatar
Maanu Grover committed
3
4
import json
import requests
Ryan Prenger's avatar
Ryan Prenger committed
5
6
7
8


def parse_token_count(raw):
    """Convert the text typed at the token-count prompt into an int.

    The reply is parsed with ``int()`` rather than ``eval()``: evaluating
    text typed at a prompt would execute arbitrary Python expressions.

    Args:
        raw: The string returned by ``input()``.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``raw`` is not a base-10 integer literal
            (surrounding whitespace is tolerated).
    """
    # NOTE(security): the previous implementation was int(eval(raw)).
    # Expressions such as "2*8" or "16.0" are intentionally no longer accepted.
    return int(raw)


if __name__ == "__main__":
    # Interactive client: repeatedly read a prompt and a token count, PUT them
    # as JSON to http://<argv[1]>/api and print the first returned text.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <host[:port]>")

    url = 'http://' + sys.argv[1] + '/api'
    headers = {'Content-Type': 'application/json'}

    while True:
        try:
            sentence = input("Enter prompt: ")
            raw_count = input("Enter number of tokens to generate: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at a prompt ends the session without a traceback.
            print()
            break

        try:
            tokens_to_generate = parse_token_count(raw_count)
        except ValueError:
            print(f"Invalid token count {raw_count!r}: expected a whole number.")
            continue

        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
        try:
            response = requests.put(url, data=json.dumps(data), headers=headers)
        except requests.exceptions.RequestException as err:
            # Connection problems should not kill the interactive loop.
            print(f"Request failed: {err}")
            continue

        if response.status_code != 200:
            # The error body is not guaranteed to be JSON with a 'message'
            # key; fall back to the raw text so the failure is still visible.
            try:
                message = response.json()['message']
            except (ValueError, KeyError, TypeError):
                message = response.text
            print(f"Error {response.status_code}: {message}")
        else:
            print("Megatron Response: ")
            print(response.json()['text'][0])