Commit 425eb81e authored by jujl1's avatar jujl1
Browse files

Merge branch 'v0.15.1-dev' into 'v0.15.1-dev-w4a8+pp_balance'

# Conflicts:
#   vllm/envs.py
parents 7b2122d9 358bc2c5
......@@ -47,6 +47,8 @@ from vllm.utils.flashinfer import (
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
logger = init_logger(__name__)
......@@ -357,6 +359,7 @@ class W8A8BlockFp8LinearOp:
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
use_blaslt: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
......@@ -364,14 +367,13 @@ class W8A8BlockFp8LinearOp:
self.is_hopper = current_platform.is_device_capability(90)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = (
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported
cutlass_block_fp8_supported, use_aiter_and_is_supported, use_blaslt
)
)
self.deepgemm_input_quant_op = (
......@@ -397,8 +399,14 @@ class W8A8BlockFp8LinearOp:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
output_dtype = input.dtype
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
out_features = int(weight.shape[-1])
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
out_features = int(weight.shape[0])
if should_use_flashinfer_for_blockscale_fp8_gemm(
self.is_flashinfer_supported, output_dtype, input_2d, weight
......@@ -413,7 +421,7 @@ class W8A8BlockFp8LinearOp:
output = self._run_deepgemm(input_2d, weight, weight_scale)
else:
output = self.w8a8_blockscale_op(
input_2d, weight, weight_scale, input_scale
out_features, input_2d, weight, weight_scale, input_scale
)
if bias is not None:
......@@ -535,6 +543,37 @@ class W8A8BlockFp8LinearOp:
input_2d.dtype,
)
def _run_hipblaslt_blockwise(
self,
out_features: int,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
m, k = input_2d.shape
n = out_features
if input_scale is None:
q_input, input_scale = self.input_quant_op(input_2d)
else:
q_input = input_2d
enum_block_size = BlockSize.block_128x128
if hasattr(self, "block_size") and self.block_size[0] == 64:
enum_block_size = BlockSize.block_64x64
output = hipblaslt_w8a8_block_fp8_matmul(
A=q_input,
B=weight,
As=input_scale,
Bs=weight_scale,
block_size=enum_block_size,
output_dtype=torch.bfloat16,
)
return output
def _run_flashinfer(
self,
input_2d: torch.Tensor,
......@@ -562,6 +601,7 @@ class W8A8BlockFp8LinearOp:
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
use_blaslt: bool,
) -> tuple[
Callable[
[
......@@ -585,6 +625,16 @@ class W8A8BlockFp8LinearOp:
)
if use_aiter_and_is_supported:
return self._run_aiter, None
if envs.VLLM_W8A8_BACKEND == 3 or use_blaslt:
return (
self._run_hipblaslt_blockwise,
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
),
)
return self._run_triton, (
QuantFP8(
False,
......@@ -1179,6 +1229,19 @@ def get_w8a8_block_fp8_configs(
)
return None
def hipblaslt_w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: BlockSize,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
assert A.shape[1] == B.shape[0]
m, k = A.shape
_, n = B.shape
_, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs, m, n, k, 'NN', output_dtype, block_size, None)
return d
def w8a8_triton_block_scaled_mm(
A: torch.Tensor,
......@@ -1597,6 +1660,10 @@ def process_fp8_weight_block_strategy(
weight=weight, weight_scale=weight_scale
)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale = weight_scale.T.contiguous()
else:
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
......
......@@ -581,6 +581,31 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
)
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
"""Update mamba_ssm_cache_dtype for Qwen3.5 models when set to 'auto'
(or not explicitly set), to the value specified in the HF config's
mamba_ssm_dtype field. Warn if the user explicitly overrides it to a
different value.
"""
cache_config = vllm_config.cache_config
hf_text_config = vllm_config.model_config.hf_text_config
mamba_ssm_dtype = getattr(hf_text_config, "mamba_ssm_dtype", None)
if cache_config.mamba_ssm_cache_dtype == "auto":
if mamba_ssm_dtype is not None:
cache_config.mamba_ssm_cache_dtype = mamba_ssm_dtype
elif (
mamba_ssm_dtype is not None
and cache_config.mamba_ssm_cache_dtype != mamba_ssm_dtype
):
logger.warning(
"Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
"but --mamba-ssm-cache-dtype='%s' was passed. "
"Using the user-specified value.",
mamba_ssm_dtype,
cache_config.mamba_ssm_cache_dtype,
)
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
......@@ -603,4 +628,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
}
......@@ -49,6 +49,7 @@ from .glm4_moe import (
get_spec_layer_idx_from_weight_name,
)
from .utils import maybe_prefix
from vllm.compilation.decorators import support_torch_compile
class SharedHead(nn.Module):
......@@ -184,6 +185,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
return logits
@support_torch_compile
class Glm4MoeMTP(nn.Module, Glm4MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3_5 MTP model."""
import typing
from collections.abc import Callable, Iterable
import torch
from torch import nn
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm
from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts
from vllm.sequence import IntermediateTensors
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
_require_is_multimodal,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
_merge_multimodal_embeddings,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
"hidden_states": 0,
}
)
class Qwen3_5MultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config
self.config = config
self.vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = torch.nn.ModuleList(
Qwen3_5DecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3_5RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_fc_norm_embedding = Qwen3_5RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_fused_expert_weights(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
) -> bool:
param = params_dict[name]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
loaded_local_expert = False
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
success = weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
return_success=True,
)
if success:
loaded_local_expert = True
return loaded_local_expert
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts
if hasattr(self.config, "num_experts")
else 0,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = (
self.config.num_experts if hasattr(self.config, "num_experts") else 0
)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
continue
if is_fused_expert:
# qwen3.5 no need to transpose
# loaded_weight = loaded_weight.transpose(-1, -2)
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
success_w1 = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
success_w3 = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
success = success_w1 and success_w3
else:
# down_proj
success = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
if success:
name = name_mapped
break
else:
# Skip loading extra bias for GPTQ models.
if (
name_mapped.endswith(".bias")
or name_mapped.endswith("_bias")
) and name_mapped not in params_dict:
continue
param = params_dict[name_mapped]
weight_loader = param.weight_loader
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
logger.warning_once(
f"Parameter {name} not found in params_dict, skip loading"
)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
"hidden_states": 0,
}
)
class Qwen3_5MTP(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_text_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3_5MTP currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3_5MultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
is_multimodal = _require_is_multimodal(is_multimodal)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif any(key in name for key in ["embed_tokens", "lm_head"]):
if "embed_tokens" in name:
name = name.replace("language_model.", "")
else:
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
class Qwen3_5MoeMTP(Qwen3_5MTP, QwenNextMixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.set_moe_parameters()
\ No newline at end of file
......@@ -46,6 +46,19 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
try:
from vllm.model_executor.layers.fused_moe.router_capture import (
maybe_record_router_logits,
)
except ImportError:
def maybe_record_router_logits(
*,
layer_name: str,
router_logits: torch.Tensor,
top_k: int,
) -> None:
return None
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -152,6 +165,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self._router_top_k = int(config.num_experts_per_tok)
self._router_capture_layer_name = prefix
if self.tp_size > config.num_experts:
raise ValueError(
......@@ -235,6 +250,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not (hasattr(torch, "compiler") and torch.compiler.is_compiling()):
capture_enabled = envs.VLLM_MOE_ROUTER_CAPTURE
if capture_enabled:
maybe_record_router_logits(
layer_name=self._router_capture_layer_name,
router_logits=router_logits,
top_k=self._router_top_k,
)
shared_out, fused_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
......@@ -341,13 +364,13 @@ class Qwen3MoeAttention(nn.Module):
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
......@@ -371,13 +394,13 @@ class Qwen3MoeAttention(nn.Module):
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
......@@ -485,9 +508,9 @@ class Qwen3MoeAttention(nn.Module):
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
self.q_norm.variance_epsilon,
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None:
......
......@@ -95,6 +95,7 @@ from .utils import (
make_layers,
maybe_prefix,
)
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -105,7 +106,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
......@@ -176,7 +177,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
renormalize=getattr(config, "norm_topk_prob", True),
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
......@@ -533,6 +534,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a = a[:num_actual_tokens]
# 1. Convolution sequence transformation
if envs.VLLM_USE_NN:
conv_weights = self.conv1d.weight.squeeze(1).transpose(
0, 1).contiguous()
else:
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
......@@ -965,7 +970,7 @@ class Qwen3NextModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config
config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
......@@ -1042,7 +1047,7 @@ class Qwen3NextModel(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_experts=getattr(self.config, "num_experts", 0),
num_redundant_experts=self.num_redundant_experts,
)
......@@ -1201,7 +1206,7 @@ class Qwen3NextForCausalLM(
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
......@@ -1265,7 +1270,7 @@ class Qwen3NextForCausalLM(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
hf_config = vllm_config.model_config.hf_text_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
......
......@@ -438,6 +438,14 @@ _MULTIMODAL_MODELS = {
"qwen3_vl_moe",
"Qwen3VLMoeForConditionalGeneration",
),
"Qwen3_5ForConditionalGeneration": (
"qwen3_5",
"Qwen3_5ForConditionalGeneration",
),
"Qwen3_5MoeForConditionalGeneration": (
"qwen3_5",
"Qwen3_5MoeForConditionalGeneration",
),
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
......@@ -480,6 +488,8 @@ _SPECULATIVE_DECODING_MODELS = {
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
"Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
"Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
......
......@@ -35,6 +35,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .protocol import RendererLike
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -448,6 +449,12 @@ def safe_apply_chat_template(
model_config=model_config,
)
if chat_template is None:
if envs.VLLM_USE_V32_ENCODE:
from vllm.entrypoints.encoding_dsv32 import encode_messages
encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True)
prompt = encode_messages(conversation, **encode_config)
return tokenizer.encode(prompt)
else:
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
......
......@@ -53,6 +53,10 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3p5Config": "vllm.transformers_utils.configs.step3p5",
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
"Qwen3_5Config": "vllm.transformers_utils.configs.qwen3_5",
"Qwen3_5TextConfig": "vllm.transformers_utils.configs.qwen3_5",
"Qwen3_5MoeConfig": "vllm.transformers_utils.configs.qwen3_5_moe",
"Qwen3_5MoeTextConfig": "vllm.transformers_utils.configs.qwen3_5_moe",
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
# Special case: DeepseekV3Config is from HuggingFace Transformers
"DeepseekV3Config": "transformers",
......@@ -95,6 +99,10 @@ __all__ = [
"Step3p5Config",
"Qwen3ASRConfig",
"Qwen3NextConfig",
"Qwen3_5Config",
"Qwen3_5TextConfig",
"Qwen3_5MoeConfig",
"Qwen3_5MoeTextConfig",
"Tarsier2Config",
]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3.5 model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
class Qwen3_5TextConfig(PretrainedConfig):
model_type = "qwen3_5_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_config_key = "text_config"
def __init__(
self,
vocab_size=248320,
hidden_size=4096,
intermediate_size=12288,
num_hidden_layers=32,
num_attention_heads=16,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
head_dim=256,
linear_conv_kernel_dim=4,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_num_key_heads=16,
linear_num_value_heads=32,
layer_types=None,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
**kwargs,
):
kwargs["ignore_keys_at_rope_validation"] = [
"mrope_section",
"mrope_interleaved",
]
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
self.rope_parameters = rope_parameters
kwargs.setdefault("partial_rotary_factor", 0.25)
self.layer_types = layer_types
if self.layer_types is None:
interval_pattern = kwargs.get("full_attention_interval", 4)
self.layer_types = [
"linear_attention"
if bool((i + 1) % interval_pattern)
else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)
# linear attention part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
super().__init__(**kwargs)
# Set these AFTER super().__init__() because transformers v4's
# PretrainedConfig.__init__ has these as explicit params with different
# defaults (e.g. tie_word_embeddings=True) that would overwrite our values.
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.tie_word_embeddings = tie_word_embeddings
class Qwen3_5VisionConfig(PretrainedConfig):
model_type = "qwen3_5"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
class Qwen3_5Config(PretrainedConfig):
model_type = "qwen3_5"
sub_configs = {
"vision_config": Qwen3_5VisionConfig,
"text_config": Qwen3_5TextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=248056,
video_token_id=248057,
vision_start_token_id=248053,
vision_end_token_id=248054,
tie_word_embeddings=False,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs)
# Set after super().__init__() to avoid v4 PretrainedConfig overwrite
self.tie_word_embeddings = tie_word_embeddings
__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3.5-MoE model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
class Qwen3_5MoeTextConfig(PretrainedConfig):
model_type = "qwen3_5_moe_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.shared_expert.gate_proj": "colwise",
"layers.*.mlp.shared_expert.up_proj": "colwise",
"layers.*.mlp.shared_expert.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_config_key = "text_config"
def __init__(
self,
vocab_size=248320,
hidden_size=2048,
num_hidden_layers=40,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
head_dim=256,
linear_conv_kernel_dim=4,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_num_key_heads=16,
linear_num_value_heads=32,
moe_intermediate_size=512,
shared_expert_intermediate_size=512,
num_experts_per_tok=8,
num_experts=256,
output_router_logits=False,
router_aux_loss_coef=0.001,
layer_types=None,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
**kwargs,
):
kwargs["ignore_keys_at_rope_validation"] = [
"mrope_section",
"mrope_interleaved",
]
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
self.rope_parameters = rope_parameters
kwargs.setdefault("partial_rotary_factor", 0.25)
self.layer_types = layer_types
if self.layer_types is None:
interval_pattern = kwargs.get("full_attention_interval", 4)
self.layer_types = [
"linear_attention"
if bool((i + 1) % interval_pattern)
else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)
# linear attention part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
super().__init__(**kwargs)
# Set these AFTER super().__init__() because transformers v4's
# PretrainedConfig.__init__ has these as explicit params with different
# defaults (e.g. tie_word_embeddings=True) that would overwrite our values.
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.tie_word_embeddings = tie_word_embeddings
class Qwen3_5MoeVisionConfig(PretrainedConfig):
model_type = "qwen3_5_moe"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
class Qwen3_5MoeConfig(PretrainedConfig):
model_type = "qwen3_5_moe"
sub_configs = {
"vision_config": Qwen3_5MoeVisionConfig,
"text_config": Qwen3_5MoeTextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=248056,
video_token_id=248057,
vision_start_token_id=248053,
vision_end_token_id=248054,
tie_word_embeddings=False,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs)
# Set after super().__init__() to avoid v4 PretrainedConfig overwrite
self.tie_word_embeddings = tie_word_embeddings
__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
......@@ -371,6 +371,11 @@ class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class Qwen3_5MTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "mtp_num_hidden_layers", 0)
class PanguUltraMoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
......@@ -396,6 +401,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"nemotron-nas": NemotronNasModelArchConfigConvertor,
"deepseek_mtp": DeepSeekMTPModelArchConfigConvertor,
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"qwen3_5_mtp": Qwen3_5MTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
......
......@@ -61,7 +61,7 @@ class W8a8GetCacheJSON:
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
self.cache_json_data = {}
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
......@@ -162,21 +162,30 @@ class W8a8GetCacheJSON:
def get_blockint8json_name(self,n,k,block_n,block_k):
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size: list | None = None, use_int4_w4a8: bool | None = False):
def get_moeint8json_name(self, E, N1, N2, K, TOPK,
block_size: list | None = None, use_int4_w4a8: bool | None = False,
use_int8_w8a8: bool | None = False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir + f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
elif use_int8_w8a8:
if block_size is not None:
return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_BLOCKFP8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W8A8FP8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
if file_path in self.cache_json_data:
# 直接返回缓存数据,避免重复读取
return self.cache_json_data[file_path]
cache_json_file=file_path
if os.path.exists(file_path):
......@@ -192,7 +201,7 @@ class W8a8GetCacheJSON:
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
self.cache_json_data[file_path] = configs_dict
return configs_dict
......@@ -1147,6 +1147,8 @@ class SpecDecodeBaseProposer:
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"GlmOcrForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, cast
import numpy as np
......@@ -47,6 +47,12 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
_prompt_token_ids_np: np.ndarray | None = field(
default=None,
init=False,
repr=False,
compare=False,
)
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
......@@ -332,15 +338,41 @@ class InputBatch:
)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = request._prompt_token_ids_np
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
rebuild_prompt_cache = (
prompt_token_ids_np.dtype != np.int32
or prompt_token_ids_np.size != num_prompt_tokens
)
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids, dtype=np.int32)
request._prompt_token_ids_np = prompt_token_ids_np
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment