Commit 7a9c4a03 authored by rprenger's avatar rprenger
Browse files

Removing bug possibilities and adding timing info

parent 29dd0a35
...@@ -61,4 +61,4 @@ class MegatronServer(object): ...@@ -61,4 +61,4 @@ class MegatronServer(object):
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model]) api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
def run(self, url): def run(self, url):
self.app.run(url, debug=False) self.app.run(url, threaded=False, debug=False)
...@@ -162,6 +162,9 @@ def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len ...@@ -162,6 +162,9 @@ def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len
def generate(model, sentences=None, max_len=0): def generate(model, sentences=None, max_len=0):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences) context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
c = context_length_tensor[0]
b = context_tokens_tensor.size(0)
start = time.time()
send_generate_info(context_tokens_tensor, context_length_tensor, max_len) send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
else: else:
context_length_tensor, context_tokens_tensor, max_len = receive_generate_info() context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
...@@ -176,6 +179,8 @@ def generate(model, sentences=None, max_len=0): ...@@ -176,6 +179,8 @@ def generate(model, sentences=None, max_len=0):
for i in range(decode_tokens.size(0)): for i in range(decode_tokens.size(0)):
decode_token = decode_tokens[i,:].cpu().numpy().tolist() decode_token = decode_tokens[i,:].cpu().numpy().tolist()
resp_sentences.append(tokenizer.detokenize(decode_token)) resp_sentences.append(tokenizer.detokenize(decode_token))
end = time.time()
print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
return resp_sentences return resp_sentences
def switch(val1, val2, boolean): def switch(val1, val2, boolean):
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import sys
import urllib2 import urllib2
class PutRequest(urllib2.Request): class PutRequest(urllib2.Request):
'''class to handling putting with urllib2''' '''class to handling putting with urllib2'''
...@@ -21,11 +22,12 @@ class PutRequest(urllib2.Request): ...@@ -21,11 +22,12 @@ class PutRequest(urllib2.Request):
return 'PUT' return 'PUT'
if __name__ == "__main__": if __name__ == "__main__":
url = sys.argv[1]
while True: while True:
sentence = raw_input("Enter prompt: ") sentence = raw_input("Enter prompt: ")
max_len = int(input("Enter number tokens output: ")) max_len = int(input("Enter number tokens output: "))
data = json.dumps({"sentences": [sentence], "max_len":max_len}) data = json.dumps({"sentences": [sentence], "max_len":max_len})
req = PutRequest("http://sc-sdgx2-484:5000/generate", data, {'Content-Type': 'application/json'}) req = PutRequest(url, data, {'Content-Type': 'application/json'})
response = urllib2.urlopen(req) response = urllib2.urlopen(req)
resp_sentences = json.load(response) resp_sentences = json.load(response)
print("Megatron Response: ") print("Megatron Response: ")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment