Commit 811183f0 authored by rprenger

Got it working on the full big model

parent ddd36145
@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
return "all_probs must be a boolean value"
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
- resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs)
+ resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
if all_probs:
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
"logits": output_logits,
"all_logits": full_logits})
"all_logits": full_logits,
"tokens": tokens})
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
@@ -121,7 +121,7 @@ def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
- input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
+ input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
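The widened tensor has to stay in sync with the sender. A sketch of what the matching send_generate_info could pack, assuming the fourth slot carries the all_probs flag as an integer; the sender side is not shown in this diff.

```python
# Hypothetical sender-side sketch; the field order beyond batch_size and seq_len,
# and the use of the fourth slot for all_probs, are assumptions.
import torch

def send_generate_info(batch_size, seq_len, max_len, all_probs):
    # One int64 broadcast keeps every rank's receive_generate_info in sync.
    input_info_tensor = torch.tensor(
        [batch_size, seq_len, max_len, int(all_probs)],
        dtype=torch.int64, device=torch.device("cuda"))
    torch.distributed.broadcast(input_info_tensor, 0)
```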
@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
args = get_args()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
- full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
+ full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
torch.distributed.broadcast(full_logits, src, group)
if tokens is not None:
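The padded_vocab_size change above is the likely reason the big model now runs: on the Megatron args namespace it is a plain integer, so the old call syntax raised a TypeError before the full-logits receive buffer could be allocated. A standalone illustration; the sizes are made up.

```python
# Illustration only; SimpleNamespace stands in for megatron.get_args().
# padded_vocab_size is an int attribute, not a method.
import torch
from types import SimpleNamespace

args = SimpleNamespace(padded_vocab_size=51200)
full_logits = torch.empty(4, 64, args.padded_vocab_size, dtype=torch.float32)
# args.padded_vocab_size(...) would raise: TypeError: 'int' object is not callable
print(full_logits.shape)  # torch.Size([4, 64, 51200])
```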
@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
tokenizer = get_tokenizer()
resp_sentences = []
resp_sentences_seg = []
- for i in range(decode_tokens.size(0)):
- decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+ decode_tokens = decode_tokens.cpu().numpy().tolist()
+ for decode_token in decode_tokens:
resp_sentences.append(tokenizer.detokenize(decode_token))
words = []
for token in decode_token:
@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
full_logits = full_logits.cpu().numpy().tolist()
end = time.time()
print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
return resp_sentences, resp_sentences_seg, output_logits, full_logits
print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
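The .cpu().numpy().tolist() conversion and the new return value go together: plain nested lists can be handed straight to jsonify in the server, whereas a torch.Tensor cannot. A small standalone check:

```python
# Why the token tensor is converted before being returned to the server:
# nested Python lists are JSON-serializable, a torch.Tensor is not.
import json
import torch

decode_tokens = torch.tensor([[15496, 995, 0]])      # toy token ids, shape (1, 3)
as_lists = decode_tokens.cpu().numpy().tolist()      # [[15496, 995, 0]]
print(json.dumps({"tokens": as_lists}))              # works
# json.dumps({"tokens": decode_tokens}) would raise TypeError
```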