Commit b7ae685f authored by Jared Casper

Merge branch 'add_all_probs' into 'main'

Let the server return the log-probabilities of the context and the generated text

See merge request ADLR/megatron-lm!317
parents 3860e995 d1b155c9
@@ -47,10 +47,24 @@ class MegatronGenerate(Resource):
         if max_len < 1:
             return "max_len must be an integer greater than 0"
 
-        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences = generate(self.model, sentences, max_len)
-        return jsonify({"sentences": resp_sentences})
+        all_probs = False
+        if "all_probs" in request.get_json():
+            all_probs = request.get_json()["all_probs"]
+            if not isinstance(all_probs, bool):
+                return "all_probs must be a boolean value"
+
+        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
+        if all_probs:
+            return jsonify({"sentences": resp_sentences,
+                            "segments": resp_sentences_seg,
+                            "logits": output_logits,
+                            "all_logits": full_logits,
+                            "tokens": tokens})
+
+        return jsonify({"sentences": resp_sentences,
+                        "segments": resp_sentences_seg,
+                        "logits": output_logits})
 
     def index():
         return current_app.send_static_file('index.html')
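For reference, a minimal client sketch against the endpoint above. The host, port, resource path (/api here), HTTP verb, and the max_len request key are assumptions about a local deployment that this diff does not pin down; only the response keys come from the handler shown.

import requests

# Hypothetical local deployment; host, port, path, and verb depend on how
# the server was launched and where MegatronGenerate is mounted.
URL = "http://localhost:5000/api"

payload = {
    "sentences": ["The quick brown fox"],
    "max_len": 16,          # assumed request key, mirroring the max_len check above
    "all_probs": True,      # ask for per-token log-probs over the full vocabulary
}

resp = requests.put(URL, json=payload).json()

print(resp["sentences"][0])   # detokenized text (context + generation)
print(resp["segments"][0])    # per-token word pieces for the same text
print(resp["logits"][0])      # log-prob of each token given its prefix
# With "all_probs": true the response also carries "all_logits" (full
# vocabulary log-probs per position) and "tokens" (the raw token ids).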
@@ -104,12 +104,12 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
+def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
     input_info_tensor = torch.cuda.LongTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
@@ -121,11 +121,12 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()
     max_len = input_info_tensor[2].item()
+    all_probs = input_info_tensor[3].item()
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
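send_generate_info and receive_generate_info stay symmetric: the new all_probs flag rides along as one more integer element of the metadata LongTensor that rank 0 broadcasts, and the other ranks read it back with .item(), so it arrives as 0/1 rather than a Python bool. A minimal single-process sketch of that packing round trip, using CPU tensors instead of torch.cuda.LongTensor and no process group:

import torch

# Pack the metadata the way send_generate_info does (the bool becomes 0/1).
max_len, all_probs = 64, True
input_info = [2, 17, max_len, all_probs]          # batch, seq_len, max_len, flag
input_info_tensor = torch.tensor(input_info, dtype=torch.int64)

# After torch.distributed.broadcast, every rank unpacks it the same way
# receive_generate_info does; note the flag comes back as an int.
recovered = input_info_tensor[3].item()
print(recovered, bool(recovered))                 # 1 True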
@@ -134,40 +135,79 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, max_len
+    return context_length_tensor, context_tokens_tensor, max_len, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
-                                                 max_len)
-    for tokens, lengths in batch_token_iterator:
+                                                 max_len,
+                                                 all_probs)
+
+    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
 
+    if mpu.is_pipeline_last_stage():
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        torch.distributed.broadcast(output_logits, src, group)
+        if all_probs:
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            torch.distributed.broadcast(full_logits, src, group)
+    else:
+        if mpu.is_pipeline_first_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+            torch.distributed.broadcast(output_logits, src, group)
+
+            if all_probs:
+                args = get_args()
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
+                torch.distributed.broadcast(full_logits, src, group)
+
     if tokens is not None:
-        return tokens[:, :context_length]
+        return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, max_len=0):
+def generate(model, sentences=None, max_len=0, all_probs=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
-        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
+        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
     else:
-        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
 
-    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
+    if output is not None:
+        decode_tokens, output_logits, full_logits = output
 
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
         resp_sentences = []
-        for i in range(decode_tokens.size(0)):
-            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+        resp_sentences_seg = []
+
+        decode_tokens = decode_tokens.cpu().numpy().tolist()
+        for decode_token in decode_tokens:
             resp_sentences.append(tokenizer.detokenize(decode_token))
-        return resp_sentences
+
+            words = []
+            for token in decode_token:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            resp_sentences_seg.append(words)
+
+        output_logits = output_logits.cpu().numpy().tolist()
+        if all_probs:
+            full_logits = full_logits.cpu().numpy().tolist()
+
+        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
 
 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """
@@ -222,7 +262,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          maxlen=None, type_ids=None):
+                          maxlen=None, all_probs=False, type_ids=None):
 
     args = get_args()
     tokenizer = get_tokenizer()
@@ -244,6 +284,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         batch_size = context_tokens.size(0)
         is_done = torch.zeros([batch_size]).byte().cuda()
         tokens = context_tokens
+        output_logits = None
+
         if maxlen is None:
             maxlen = args.seq_length - 1
@@ -269,6 +311,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
+
             output, layer_past = forward_step(model, tokens2use,
                                               positions2use,
                                               attention_mask,
@@ -296,6 +339,24 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
+
+                if output_logits is None:
+                    output_context = F.log_softmax(output[:, :context_length, :], 2)
+                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
+                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                    if all_probs:
+                        full_logits = output_context
+                else:
+                    output_context = F.log_softmax(output, 2)
+                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
+                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+
+                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
+                    output_logits = torch.cat([output_logits, new_output_logits],1)
+                    if all_probs:
+                        full_logits = torch.cat([full_logits, output_context], 1)
+                    #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
+
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
                 torch.distributed.broadcast(new_tokens, src, group)
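This hunk is the core of the change: on the last pipeline stage, the model output is passed through log_softmax and, for every position, the log-probability of the token that actually follows it is gathered; with all_probs the full per-position distribution is kept as well. A small CPU-only sketch of the same gather, with random logits standing in for the model output:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, context_length, vocab = 2, 5, 11

# Stand-in for the model output on the prompt: [batch, context_length, vocab].
output = torch.randn(batch, context_length, vocab)
tokens = torch.randint(vocab, (batch, context_length + 1))

# Log-probabilities over the vocabulary at every position.
output_context = F.log_softmax(output, dim=2)            # [batch, ctx, vocab]

# For position j, pick the log-prob of the token that comes next (j+1),
# mirroring the output_logits gather in sample_sequence_batch.
indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
output_logits = torch.gather(output_context, 2, indices).squeeze(2)

print(output_logits.shape)   # torch.Size([2, 5]): one log-prob per position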
@@ -309,7 +370,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                yield tokens, lengths
+                if all_probs:
+                    yield tokens, lengths, output_logits, full_logits
+                else:
+                    yield tokens, lengths, output_logits, None
 
             else:
                 if mpu.is_pipeline_first_stage():
@@ -318,9 +382,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None
+                    yield tokens, None, None, None
                 else:
-                    yield None, None
+                    yield None, None, None, None
 
                 done = torch.cuda.ByteTensor([0])
                 src = mpu.get_pipeline_model_parallel_last_rank()
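The widened yield signature keeps every pipeline rank iterating in lockstep: ranks that do not hold the data yield None placeholders of the same arity, so the collectives inside the loop stay aligned. A toy generator illustrating that contract (names and values are illustrative only, not part of the diff):

def toy_iterator(is_last_stage, steps=3):
    # Mirrors the yield contract of sample_sequence_batch: every rank runs
    # the same number of steps; only the last stage carries real values,
    # the others yield None placeholders in the same 4-tuple shape.
    for step in range(steps):
        if is_last_stage:
            yield f"tokens@{step}", f"lengths@{step}", f"logits@{step}", None
        else:
            yield None, None, None, None

# The caller unpacks the same 4-tuple regardless of which rank it runs on.
for tokens, lengths, output_logits, full_logits in toy_iterator(is_last_stage=True):
    print(tokens, lengths, output_logits, full_logits)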