Commit e0bf5199 authored by rprenger

Outputting log probabilities

parent 279d8320
@@ -48,9 +48,10 @@ class MegatronGenerate(Resource):
             return "max_len must be an integer greater than 0"
 
         MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
-        resp_sentences = generate(self.model, sentences, max_len)
-        return jsonify({"sentences": resp_sentences})
+        resp_sentences, resp_sentences_seg, output_logits = generate(self.model, sentences, max_len)
+        return jsonify({"sentences": resp_sentences,
+                        "segments": resp_sentences_seg,
+                        "logits": output_logits})
 
 def index():
     return current_app.send_static_file('index.html')
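
Note: with this change a client gets per-token segments and log probabilities alongside the generated text. A minimal client-side sketch; the host, port, endpoint path, and HTTP verb are assumptions that depend on how the server is launched, not part of this commit:

import requests

# Host/port/path/verb below are assumptions; adjust to the actual server setup.
resp = requests.put("http://localhost:5000/generate",
                    json={"sentences": ["The quick brown fox"], "max_len": 8})
data = resp.json()

# "sentences" holds detokenized strings, "segments" per-token word pieces,
# and "logits" per-token log probabilities. The two lists can differ in
# length by one (the first token has no preceding context); zip truncates.
for sent, seg, lps in zip(data["sentences"], data["segments"], data["logits"]):
    print(sent)
    for word, lp in zip(seg, lps):
        print("  %r: %.3f" % (word, lp))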
@@ -144,11 +144,22 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
                                             context_length_tensor,
                                             attention_mask, position_ids,
                                             max_len)
-    for tokens, lengths in batch_token_iterator:
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1
 
+    if mpu.is_pipeline_last_stage():
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        torch.distributed.broadcast(output_logits, src, group)
+    else:
+        if mpu.is_pipeline_first_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+            torch.distributed.broadcast(output_logits, src, group)
+
     if tokens is not None:
-        return tokens[:, :context_length]
+        return tokens[:, :context_length], output_logits
 
 def generate(model, sentences=None, max_len=0):
     if torch.distributed.get_rank() == 0:
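
Note: the new block follows the standard torch.distributed collective pattern: the rank that owns the data (the last pipeline stage) calls broadcast with the real tensor, and a receiving rank must pre-allocate a buffer of matching shape and dtype before joining the same call. A stripped-down, single-process sketch of that pattern (the gloo group and sizes are illustrative only, not Megatron's actual process-group layout):

import torch
import torch.distributed as dist

# Single-process gloo group purely so the sketch executes; Megatron uses its
# own pipeline/embedding process groups instead.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

batch_size, context_length = 2, 6
if dist.get_rank() == 0:
    # Source rank (the last pipeline stage above): has the computed values.
    output_logits = torch.randn(batch_size, context_length - 1)
else:
    # Receiving rank (the first pipeline stage above): allocates an empty
    # buffer of the same shape/dtype for broadcast to fill in.
    output_logits = torch.empty(batch_size, context_length - 1,
                                dtype=torch.float32)
dist.broadcast(output_logits, src=0)
print(output_logits.shape)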
@@ -160,18 +171,29 @@ def generate(model, sentences=None, max_len=0):
     else:
         context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
 
-    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    if output is not None:
+        decode_tokens, output_logits = output
 
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
         resp_sentences = []
+        resp_sentences_seg = []
         for i in range(decode_tokens.size(0)):
             decode_token = decode_tokens[i,:].cpu().numpy().tolist()
             resp_sentences.append(tokenizer.detokenize(decode_token))
+            words = []
+            for token in decode_token:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            resp_sentences_seg.append(words)
+
+        output_logits = output_logits.cpu().numpy().tolist()
         end = time.time()
         print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-        return resp_sentences
+        return resp_sentences, resp_sentences_seg, output_logits
 
 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
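
Note: the inner loop above is what populates "segments": each token id is decoded on its own rather than as part of the full string. Pulled out for clarity (this assumes the GPT-2-style BPE tokenizer that Megatron wraps, whose decoder maps ids to BPE symbols and whose byte_decoder maps symbol characters back to raw bytes):

def token_to_word(bpe_tokenizer, token_id):
    # Map the id to its BPE symbol, e.g. 'Ġfox' ('Ġ' encodes a leading space).
    symbol = bpe_tokenizer.decoder[token_id]
    # Undo the byte-to-unicode mapping that GPT-2 BPE applies to raw bytes.
    raw = bytearray(bpe_tokenizer.byte_decoder[c] for c in symbol)
    # errors='replace' keeps the server from crashing when a token splits a
    # multi-byte UTF-8 character across BPE pieces.
    return raw.decode('utf-8', errors='replace')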
@@ -236,6 +258,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         batch_size = context_tokens.size(0)
         is_done = torch.zeros([batch_size]).byte().cuda()
         tokens = context_tokens
+        output_logits = None
+
         if maxlen is None:
             maxlen = args.seq_length - 1
@@ -261,6 +285,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
+
                 output, layer_past = forward_step(model, tokens2use,
                                                   positions2use,
                                                   attention_mask,
@@ -288,6 +313,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
+
+                if output_logits is None:
+                    output_context = F.log_softmax(output[:, :context_length, :], 2)
+                    indices = torch.unsqueeze(tokens[:, :context_length],2)
+                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                else:
+                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
+                    new_output_logits = torch.gather(F.log_softmax(output,2), 2, indices).squeeze(2)
+
+                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
+                    output_logits = torch.cat([output_logits, new_output_logits],1)
+                    #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
+
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
                 torch.distributed.broadcast(new_tokens, src, group)
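
Note: the first branch of the new block reduces to a compact pattern: log-softmax the logits over the vocabulary, then gather each position's distribution at the token actually placed there. A runnable standalone sketch with made-up sizes:

import torch
import torch.nn.functional as F

# Given logits [batch, seq, vocab] and sampled tokens [batch, seq],
# extract the log probability each position assigns to its token.
batch, seq, vocab = 2, 5, 50257          # illustrative sizes
logits = torch.randn(batch, seq, vocab)
tokens = torch.randint(vocab, (batch, seq))

log_probs = F.log_softmax(logits, dim=2)              # normalize over vocab
indices = tokens.unsqueeze(2)                         # [batch, seq, 1]
token_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
print(token_log_probs.shape)                          # torch.Size([2, 5])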
@@ -301,7 +339,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                yield tokens, lengths
+                yield tokens, lengths, output_logits
 
             else:
                 if mpu.is_pipeline_first_stage():
@@ -310,9 +348,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None
+                    yield tokens, None, None
                 else:
-                    yield None, None
+                    yield None, None, None
 
                 done = torch.cuda.ByteTensor([0])
                 src = mpu.get_pipeline_model_parallel_last_rank()