Commit a2ed85a2 (unverified), authored by Abrar Shivani, committed by GitHub
Browse files

fix: Use min of max tokens or context length (#1322)

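This PR modifies the mistralrs engine so that the maximum output token count never exceeds the engine's context length: in both the chat-completion and completion request paths, the requested max tokens (or, when the request specifies none, the default `max_len`) is capped with `min(value, context_length)`.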
parent 9bf79b67
...@@ -65,6 +65,7 @@ fn best_device() -> pipeline_error::Result<Device> { ...@@ -65,6 +65,7 @@ fn best_device() -> pipeline_error::Result<Device> {
struct MistralRsEngine { struct MistralRsEngine {
mistralrs: Arc<MistralRs>, mistralrs: Arc<MistralRs>,
context_length: usize,
} }
impl MistralRsEngine { impl MistralRsEngine {
...@@ -203,6 +204,7 @@ impl MistralRsEngine { ...@@ -203,6 +204,7 @@ impl MistralRsEngine {
.with_prefix_cache_n(16); .with_prefix_cache_n(16);
let engine = MistralRsEngine { let engine = MistralRsEngine {
mistralrs: builder.build(), mistralrs: builder.build(),
context_length: max_seq_len,
}; };
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218 // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
...@@ -310,12 +312,21 @@ impl ...@@ -310,12 +312,21 @@ impl
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty), frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty), presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks), stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
max_len: request max_len: {
.inner let requested_max_tokens = request
.max_completion_tokens .inner
.or(request.inner.max_tokens) .max_completion_tokens
.map(|m| m as usize) .or(request.inner.max_tokens)
.or(det.max_len), .map(|m| m as usize);
// Ensure max_len doesn't exceed context length
match requested_max_tokens {
Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request logits_bias: request
.inner .inner
.logit_bias .logit_bias
...@@ -499,12 +510,17 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -499,12 +510,17 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
.clone() .clone()
.map(to_stop_tokens) .map(to_stop_tokens)
.or(det.stop_toks), .or(det.stop_toks),
max_len: request max_len: {
.inner let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);
.max_tokens
.or(request.inner.max_tokens) // Ensure max_len doesn't exceed context length
.map(|m| m as usize) match requested_max_tokens {
.or(det.max_len), Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request logits_bias: request
.inner .inner
.logit_bias .logit_bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment