Unverified commit 44a2cba9, authored by jthomson04 and committed by GitHub
Browse files

feat: Mistral-3-large support (#4885)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
parent 87112373
......@@ -12,7 +12,6 @@ class InputParamManager:
"""
if use_tokenizer:
print(f"Request: {request}")
if self.tokenizer is None:
raise ValueError("Tokenizer is not available")
......@@ -21,10 +20,9 @@ class InputParamManager:
request["messages"], tokenize=False, add_generation_prompt=True
)
elif "prompt" in request:
return request["prompt"]
return self.tokenizer.encode(request["prompt"])
elif "text" in request:
return request["text"]
return self.tokenizer.encode(request["text"])
else:
raise ValueError("No input parameter found in request")
return request.get("token_ids")
......@@ -758,7 +758,6 @@ class BaseWorkerHandler(ABC):
logger.debug(
f"Starting token generation for request {request_id} (no LoRA)"
)
gen = self.engine_client.generate(
prompt,
sampling_params,
......@@ -947,12 +946,15 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async def _generate_text_mode(self, request, context, request_id):
"""Generate text using OpenAI-compatible format (text-in-text-out)."""
# Get text input using InputParamManager
input_text = self.input_param_manager.get_input_param(
input_data = self.input_param_manager.get_input_param(
request, use_tokenizer=True
)
# Build prompt for vLLM
prompt = TextPrompt(prompt=input_text)
if isinstance(input_data, list):
prompt = TokensPrompt(prompt_token_ids=input_data)
else:
prompt = TextPrompt(prompt=input_data)
# Build sampling params from OpenAI-style request
sampling_params = build_sampling_params_openai(
......@@ -1050,11 +1052,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
......@@ -1164,77 +1161,3 @@ class PrefillWorkerHandler(BaseWorkerHandler):
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
async def _generate_text_mode(self, request, context, request_id):
"""Generate prefill using OpenAI-compatible format (text-in-text-out).

Resolves the request input through ``self.input_param_manager`` (with
``use_tokenizer=True``), wraps it in a ``TextPrompt``, and submits it to
``self.engine_client.generate`` with sampling params that are forced to
produce exactly one token, with detokenization disabled and
``kv_transfer_params`` marked for remote decode.

Yields one dict per engine result with the keys ``token_ids``,
``disaggregated_params`` (``None`` when the result carries no KV transfer
params) and ``completion_usage``.

Raises ``GeneratorExit`` if the iteration is cancelled
(``asyncio.CancelledError``). On ``EngineDeadError`` the runtime is shut
down and the process exits via ``os._exit(1)``.

NOTE(review): indentation was lost in this view; presumably both ``try``
blocks below are nested inside the ``async with`` -- confirm against the
original file.
"""
# Get text input using InputParamManager
input_text = self.input_param_manager.get_input_param(
request, use_tokenizer=True
)
# Build prompt for vLLM
prompt = TextPrompt(prompt=input_text)
# Build sampling params from OpenAI-style request
sampling_params = build_sampling_params_openai(
request, self.default_sampling_params
)
sampling_params.detokenize = False # Prefill doesn't need detokenization
# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
# Any "kv_transfer_params" already present in extra_args is replaced here.
sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
sampling_params_defaults = {
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
# Add only missing keys
# (the dict was just created with a single key, so all five defaults are added)
for k, v in sampling_params_defaults.items():
sampling_params.extra_args["kv_transfer_params"].setdefault(k, v)
# Override for prefill: only generate 1 token
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
# Optional data-parallel rank from the request; None when the key is absent.
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e:
# A dead engine is treated as fatal: log, shut the runtime down, and
# terminate the process immediately (os._exit skips normal cleanup).
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
# Only the first output's tokens are forwarded; empty list if there are no outputs.
token_ids = res.outputs[0].token_ids if res.outputs else []
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res
),
}
yield output
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
# `from None` suppresses the CancelledError context on the raised exception.
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
......@@ -162,6 +162,11 @@ impl GenerationConfig {
}
}
/// Report whether `directory` carries only Mistral-format model configuration.
///
/// Returns `true` when `params.json` is present in `directory` and
/// `config.json` is not. Returns `false` in every other case, including
/// when neither file is present.
fn is_exclusively_mistral_model(directory: &Path) -> bool {
    // A `config.json` rules the directory out before `params.json` is probed.
    if directory.join("config.json").exists() {
        return false;
    }
    directory.join("params.json").exists()
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
......@@ -352,7 +357,9 @@ impl ModelDeploymentCard {
.with_context(|| p.display().to_string())
}
None => {
anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
anyhow::bail!(
"Blank ModelDeploymentCard does not have a tokenizer. Is this a mistral model? If so, the `--use-<framework>-tokenizer` flag in the engine command is required."
);
}
}
}
......@@ -497,8 +504,23 @@ impl ModelDeploymentCard {
// If neither of those are present let the engine default it
.unwrap_or(0);
let is_mistral_model = is_exclusively_mistral_model(local_path);
let (model_info, tokenizer, gen_config, prompt_formatter) = if !is_mistral_model {
(
Some(ModelInfoType::from_disk(local_path)?),
Some(TokenizerKind::from_disk(local_path)?),
GenerationConfig::from_disk(local_path).ok(),
PromptFormatterArtifact::from_disk(local_path)?,
)
} else {
(None, None, None, None)
};
// Load chat template - either custom or from repo
let chat_template_file = if let Some(template_path) = custom_template_path {
let chat_template_file = if is_mistral_model {
None
} else if let Some(template_path) = custom_template_path {
if !template_path.exists() {
anyhow::bail!(
"Custom template file does not exist: {}",
......@@ -525,10 +547,10 @@ impl ModelDeploymentCard {
Ok(Self {
slug: Slug::from_string(&display_name),
display_name,
model_info: Some(ModelInfoType::from_disk(local_path)?),
tokenizer: Some(TokenizerKind::from_disk(local_path)?),
gen_config: GenerationConfig::from_disk(local_path).ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_disk(local_path)?,
model_info,
tokenizer,
gen_config,
prompt_formatter,
chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context
context_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment