Unverified commit a2ed85a2, authored by Abrar Shivani and committed by GitHub
Browse files

fix: Use min of max tokens or context length (#1322)

This PR modifies the mistralrs engine to ensure that the maximum output token length never exceeds the context length provided.
parent 9bf79b67
......@@ -65,6 +65,7 @@ fn best_device() -> pipeline_error::Result<Device> {
/// Engine wrapper around a shared `MistralRs` instance.
///
/// The context length is stored next to the handle so that a request's
/// `max_len` can be clamped to it when the request is translated.
struct MistralRsEngine {
/// Shared handle to the `MistralRs` instance (populated from `builder.build()`).
mistralrs: Arc<MistralRs>,
/// Upper bound applied to `max_len`; populated from `max_seq_len` at construction.
// NOTE(review): the clamp compares against the full context length and does not
// appear to subtract the prompt's token count — confirm that is intended.
context_length: usize,
}
impl MistralRsEngine {
......@@ -203,6 +204,7 @@ impl MistralRsEngine {
.with_prefix_cache_n(16);
let engine = MistralRsEngine {
mistralrs: builder.build(),
context_length: max_seq_len,
};
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
......@@ -310,12 +312,21 @@ impl
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
max_len: request
max_len: {
let requested_max_tokens = request
.inner
.max_completion_tokens
.or(request.inner.max_tokens)
.map(|m| m as usize)
.or(det.max_len),
.map(|m| m as usize);
// Ensure max_len doesn't exceed context length
match requested_max_tokens {
Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request
.inner
.logit_bias
......@@ -499,12 +510,17 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
.clone()
.map(to_stop_tokens)
.or(det.stop_toks),
max_len: request
.inner
.max_tokens
.or(request.inner.max_tokens)
.map(|m| m as usize)
.or(det.max_len),
max_len: {
let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);
// Ensure max_len doesn't exceed context length
match requested_max_tokens {
Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request
.inner
.logit_bias
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment