Commit 811183f0 authored by rprenger

Got it working on the full big model

parent ddd36145
@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
             return "all_probs must be a boolean value"
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
                             "logits": output_logits,
-                            "all_logits": full_logits})
+                            "all_logits": full_logits,
+                            "tokens": tokens})
         return jsonify({"sentences": resp_sentences,
                         "segments": resp_sentences_seg,
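With this change the all_probs response also carries the generated token ids under a new "tokens" key. A minimal client sketch, assuming a PUT request and a /generate route on localhost:5000 (the HTTP verb, route, host, and port are not shown in this commit; only the request fields and response keys come from the hunk above):

import requests

# Hypothetical URL -- adjust to wherever MegatronGenerate is actually registered.
resp = requests.put("http://localhost:5000/generate",
                    json={"sentences": ["Megatron is"],
                          "max_len": 32,
                          "all_probs": True})
out = resp.json()
print(out["sentences"])        # detokenized completions
print(out["tokens"])           # token ids, new in this commit
print(len(out["all_logits"]))  # full per-step distributions when all_probs is set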
@@ -121,7 +121,7 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
+    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()
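receive_generate_info now expects four int64 values instead of three, so the rank-0 sender has to pack a fourth entry. A hedged sketch of the matching sender (the ordering of the last two slots -- max_len and the all_probs flag -- is an assumption; the diff only shows the tensor growing to four entries):

import torch

def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
    # Pack the scalar metadata into one int64 tensor and broadcast it from rank 0,
    # mirroring the torch.empty(4, ...) + broadcast on the receiving ranks.
    batch_size, seq_len = context_tokens_tensor.shape
    input_info_tensor = torch.tensor([batch_size, seq_len, max_len, int(all_probs)],
                                     dtype=torch.int64, device=torch.device("cuda"))
    torch.distributed.broadcast(input_info_tensor, 0)
    # The context token and length tensors themselves are broadcast separately.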
@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
         torch.distributed.broadcast(output_logits, src, group)
         if all_probs:
+            args = get_args()
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
-            full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
+            full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
             torch.distributed.broadcast(full_logits, src, group)
     if tokens is not None:
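Two fixes land here: the all_probs branch now fetches args via get_args(), and padded_vocab_size is read as an attribute rather than called like a method. The receiving ranks allocate a (batch, context_length, padded_vocab_size) float32 buffer and fill it via broadcast from the last pipeline stage over the embedding group. A hedged sketch of the sending side, which is not part of this hunk (the is_pipeline_last_stage guard is an assumption; only the shape, dtype, src rank, and group come from the diff):

    if all_probs and mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # full_logits holds the per-step distributions computed during sampling;
        # its shape must match the torch.empty(...) allocation on the receivers.
        torch.distributed.broadcast(full_logits, src, group)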
@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
         tokenizer = get_tokenizer()
         resp_sentences = []
         resp_sentences_seg = []
-        for i in range(decode_tokens.size(0)):
-            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+        decode_tokens = decode_tokens.cpu().numpy().tolist()
+        for decode_token in decode_tokens:
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:
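decode_tokens is now converted to a plain nested Python list up front, so the same object can be detokenized here and later handed back to the caller (and through jsonify) without another tensor-to-list conversion. A toy illustration of the shape the new loop works with, using made-up ids and a stand-in tokenizer:

decode_tokens = [[5431, 318, 257], [2396, 340, 2499]]  # hypothetical ids, two sequences
for decode_token in decode_tokens:                      # one pass per generated sequence
    resp_sentences.append(tokenizer.detokenize(decode_token))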
@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
             full_logits = full_logits.cpu().numpy().tolist()
         end = time.time()
-        print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits
+        print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
+        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
 
 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)