Commit f47aa770 authored by Jared Casper

Merge branch 'add-temperature-parameter-to-server-api' into 'main'

Add temperature to the server API

See merge request ADLR/megatron-lm!325
parents a97d676b 527e07c0
@@ -55,8 +55,15 @@ class MegatronGenerate(Resource):
         if not isinstance(all_probs, bool):
             return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+               0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
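For illustration, a client request against the updated server might look like the sketch below. The JSON field names (`sentences`, `tokens_to_generate`, `all_probs`, `temperature`) come from the diff; the host, port, route, and HTTP verb are assumptions about a typical deployment of this server, not something this MR specifies.

```python
# Hypothetical client call; only the JSON field names are taken from the diff above.
# Host, port, route ("/api"), and the PUT verb are assumptions.
import requests

response = requests.put(
    "http://localhost:5000/api",
    json={
        "sentences": ["Megatron-LM is"],  # prompts to complete
        "tokens_to_generate": 32,         # number of new tokens to sample
        "all_probs": False,               # skip returning full logits
        "temperature": 0.7,               # must be a float in (0.0, 100.0]
    },
    headers={"Content-Type": "application/json"},
)
print(response.json()["sentences"])
```

Omitting `temperature` from the request keeps the server-side default (`args.temperature`), so existing clients are unaffected.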
@@ -138,14 +138,15 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs)
+                                                 all_probs,
+                                                 temperature=temperature)
 
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
@@ -174,16 +175,15 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
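As a usage sketch (not part of the MR), the new keyword can also be passed when calling `generate` directly. The model object and `torch.distributed` initialization are assumed to already exist, as they do in the server path; the five return values match the unpacking shown in the server diff above.

```python
# Sketch only: assumes `model` is an already-loaded Megatron GPT model, that
# torch.distributed is initialized, and that this runs on rank 0 (other ranks
# receive the prompt and settings via receive_generate_info()).
# If temperature is omitted it defaults to 1.0, i.e. the logits are left unscaled.
resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(
    model,
    sentences=["Deep learning is"],
    tokens_to_generate=32,
    all_probs=False,
    temperature=0.7,  # < 1.0 sharpens the sampling distribution, > 1.0 flattens it
)
print(resp_sentences[0])
```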
@@ -262,7 +262,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
 
     args = get_args()
     tokenizer = get_tokenizer()
@@ -324,7 +324,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
-                   logits /= args.temperature
+                   logits /= temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
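For context, the line `logits /= temperature` is standard temperature scaling applied before top-k/top-p filtering and sampling. The standalone sketch below illustrates its effect on a toy distribution; it is not the Megatron code path, and the final multinomial draw is shown only as an example of how a token would then be sampled.

```python
# Standalone illustration of temperature scaling; not the Megatron code path.
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])

for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    # Lower temperature concentrates probability mass on the highest-logit token;
    # higher temperature pushes the distribution toward uniform.
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

# A token is then drawn from the scaled distribution, e.g.:
next_token = torch.multinomial(F.softmax(logits / 0.7, dim=-1), num_samples=1)
```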