Unverified Commit 527e07c0 authored by Robert Clark

Add temperature to the server API



A temperature value greater than 0.0 and at most 100.0 can now be specified
per request via the API of the text generation server. The value passed to
--temperature when launching the server remains the default for every API
call that does not include a temperature field; a temperature supplied in
one request does not change that default for later requests.
Signed-Off-By: Robert Clark <roclark@nvidia.com>
parent a97d676b
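
For illustration, a minimal client-side request exercising the new field (a sketch only: the /api path, port 5000, and the PUT verb are assumptions not shown in this commit; the sentences, tokens_to_generate, and temperature fields follow the handler in the diff below):

import requests

# Hypothetical call to the text generation server; only the JSON field names
# are taken from this commit -- the URL and HTTP method are assumptions.
response = requests.put(
    "http://localhost:5000/api",
    json={
        "sentences": ["Deep learning is"],
        "tokens_to_generate": 32,
        "temperature": 0.7,  # omit to fall back to the --temperature default
    },
)
print(response.json())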
@@ -55,8 +55,15 @@ class MegatronGenerate(Resource):
         if not isinstance(all_probs, bool):
             return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+                0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                 "segments": resp_sentences_seg,
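A note on the check above: because it requires isinstance(temperature, float), a JSON integer such as 2 is rejected even though it lies in range, while 2.0 is accepted. A small standalone sketch of the same predicate:

def temperature_ok(t):
    # Mirrors the server-side check: a float strictly greater than 0.0
    # and no larger than 100.0.
    return isinstance(t, float) and 0.0 < t <= 100.0

for candidate in [0.7, 100.0, 0.0, -1.0, 2, "1.0"]:
    print(candidate, temperature_ok(candidate))
# 0.7 and 100.0 pass; 0.0, -1.0, the integer 2, and the string "1.0" are rejected.
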
@@ -138,14 +138,15 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs)
+                                                 all_probs,
+                                                 temperature=temperature)
 
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
@@ -174,16 +175,15 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
 
     if output is not None:
         decode_tokens, output_logits, full_logits = output
@@ -262,7 +262,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
 
     args = get_args()
     tokenizer = get_tokenizer()
@@ -324,7 +324,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
                 logits = logits.float()
-                logits /= args.temperature
+                logits /= temperature
                 logits = top_k_logits(logits, top_k=args.top_k,
                                       top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
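
For intuition on the logits /= temperature line above: dividing by a temperature below 1.0 sharpens the softmax toward the highest-scoring token, while a value above 1.0 flattens it. A self-contained sketch (assumes only PyTorch; the logit values are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
# Lower temperature concentrates probability mass on the largest logit;
# higher temperature spreads it more evenly across tokens.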