Unverified Commit 4ede59a2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: speculative prefill (#6230)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent f4f82762
...@@ -83,6 +83,38 @@ pub enum WorkerType { ...@@ -83,6 +83,38 @@ pub enum WorkerType {
Decode, Decode,
} }
/// Configuration for reasoning/thinking token output in the mocker.
///
/// When set, the mocker wraps the first portion of each response in thinking
/// boundary tokens: `[start_token, random..., end_token, random...]`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ReasoningConfig {
pub start_thinking_token_id: u32,
pub end_thinking_token_id: u32,
#[validate(range(min = 0.0, max = 1.0))]
pub thinking_ratio: f64,
}
impl ReasoningConfig {
/// Number of thinking tokens (including start/end boundaries) for a given osl.
/// Returns 0 if osl < 2 (thinking disabled). Otherwise clamps to [2, osl].
pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
if max_output_tokens < 2 {
return 0;
}
let raw = (max_output_tokens as f64 * self.thinking_ratio).floor() as usize;
if raw == 0 {
return 0;
}
raw.max(2).min(max_output_tokens)
}
/// Number of response tokens after the thinking block.
pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
max_output_tokens.saturating_sub(self.num_thinking_tokens(max_output_tokens))
}
}
/// Configuration arguments for MockVllmEngine /// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))] #[builder(pattern = "owned", build_fn(public))]
...@@ -146,6 +178,11 @@ pub struct MockEngineArgs { ...@@ -146,6 +178,11 @@ pub struct MockEngineArgs {
/// If None, bootstrap rendezvous is disabled. /// If None, bootstrap rendezvous is disabled.
#[builder(default = "None")] #[builder(default = "None")]
pub bootstrap_port: Option<u16>, pub bootstrap_port: Option<u16>,
/// Reasoning/thinking token configuration.
/// When set, the mocker wraps output in thinking boundary tokens.
#[builder(default = "None")]
pub reasoning: Option<ReasoningConfig>,
} }
impl Default for MockEngineArgs { impl Default for MockEngineArgs {
...@@ -198,6 +235,7 @@ impl MockEngineArgs { ...@@ -198,6 +235,7 @@ impl MockEngineArgs {
"planner_profile_data", "planner_profile_data",
"enable_local_indexer", "enable_local_indexer",
"bootstrap_port", "bootstrap_port",
"reasoning",
] ]
.iter() .iter()
.cloned() .cloned()
...@@ -291,6 +329,12 @@ impl MockEngineArgs { ...@@ -291,6 +329,12 @@ impl MockEngineArgs {
builder = builder.bootstrap_port(Some(port as u16)); builder = builder.bootstrap_port(Some(port as u16));
} }
if let Some(value) = extra_args.get("reasoning") {
let cfg: ReasoningConfig = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
builder = builder.reasoning(Some(cfg));
}
// Parse worker type from is_prefill and is_decode flags // Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args let is_prefill = extra_args
.get("is_prefill") .get("is_prefill")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment