"docs/vscode:/vscode.git/clone" did not exist on "96266f119bb93516703328f9e37ec99cce45f792"
Unverified Commit 4ede59a2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: speculative prefill (#6230)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent f4f82762
......@@ -83,6 +83,38 @@ pub enum WorkerType {
Decode,
}
/// Configuration for reasoning/thinking token output in the mocker.
///
/// When set, the mocker wraps the first portion of each response in thinking
/// boundary tokens: `[start_token, random..., end_token, random...]`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ReasoningConfig {
pub start_thinking_token_id: u32,
pub end_thinking_token_id: u32,
#[validate(range(min = 0.0, max = 1.0))]
pub thinking_ratio: f64,
}
impl ReasoningConfig {
/// Number of thinking tokens (including start/end boundaries) for a given osl.
/// Returns 0 if osl < 2 (thinking disabled). Otherwise clamps to [2, osl].
pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
if max_output_tokens < 2 {
return 0;
}
let raw = (max_output_tokens as f64 * self.thinking_ratio).floor() as usize;
if raw == 0 {
return 0;
}
raw.max(2).min(max_output_tokens)
}
/// Number of response tokens after the thinking block.
pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
max_output_tokens.saturating_sub(self.num_thinking_tokens(max_output_tokens))
}
}
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))]
......@@ -146,6 +178,11 @@ pub struct MockEngineArgs {
/// If None, bootstrap rendezvous is disabled.
#[builder(default = "None")]
pub bootstrap_port: Option<u16>,
/// Reasoning/thinking token configuration.
/// When set, the mocker wraps output in thinking boundary tokens.
#[builder(default = "None")]
pub reasoning: Option<ReasoningConfig>,
}
impl Default for MockEngineArgs {
......@@ -198,6 +235,7 @@ impl MockEngineArgs {
"planner_profile_data",
"enable_local_indexer",
"bootstrap_port",
"reasoning",
]
.iter()
.cloned()
......@@ -291,6 +329,12 @@ impl MockEngineArgs {
builder = builder.bootstrap_port(Some(port as u16));
}
if let Some(value) = extra_args.get("reasoning") {
let cfg: ReasoningConfig = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
builder = builder.reasoning(Some(cfg));
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment