Unverified Commit 0acac5cb authored by drbh's avatar drbh Committed by GitHub
Browse files

feat: improve temperature logic in chat (#1749)



This PR adds support for `do_sample` to chat to enable greedy sampling

---------
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
parent 4c698fa6
......@@ -1000,6 +1000,7 @@ async fn chat_completions(
tools,
tool_choice,
tool_prompt,
temperature,
..
} = req;
......@@ -1008,6 +1009,11 @@ async fn chat_completions(
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
......@@ -1054,13 +1060,13 @@ async fn chat_completions(
inputs: inputs.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
......
......@@ -273,7 +273,7 @@ class HeterogeneousNextTokenChooser:
else None
)
if any([x != 1.0 for x in temperature]):
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
......@@ -281,15 +281,15 @@ class HeterogeneousNextTokenChooser:
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any([x != 0 for x in top_k]):
if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]):
if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]):
if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment