Commit 855cb148 authored by 王敏's avatar 王敏
Browse files

merge dev分支代码

parents 9135afe4 fe2e2705
......@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
......
......@@ -31,6 +31,7 @@ from typing import Optional, Union
import torch
from torch import nn
from transformers import Glm4Config
import vllm.envs as envs
class MultiModalConfigProxy:
......@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......
......@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
if envs.VLLM_USE_FUSED_RMS_ROPE:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
# Add qk-norm then RoPE (original path).
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......
......@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
try:
from vllm.model_executor.layers.fused_moe.router_capture import (
maybe_record_router_logits,
)
except ImportError:
def maybe_record_router_logits(
*,
layer_name: str,
router_logits: torch.Tensor,
top_k: int,
) -> None:
return None
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self._router_top_k = int(config.num_experts_per_tok)
self._router_capture_layer_name = prefix
if self.tp_size > config.num_experts:
raise ValueError(
......@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()):
capture_enabled = envs.VLLM_MOE_ROUTER_CAPTURE
if capture_enabled:
maybe_record_router_logits(
layer_name=self._router_capture_layer_name,
router_logits=router_logits,
top_k=self._router_top_k,
)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
......@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......
......@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
class TeleChat2Model(LlamaModel):
......@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
......
......@@ -17,7 +17,7 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
self.cache_json_data = {}
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
......@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False,use_int8_w8a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
elif use_int8_w8a8:
if block_size is not None:
return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_BLOCKFP8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W8A8FP8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
if file_path in self.cache_json_data:
# 直接返回缓存数据,避免重复读取
return self.cache_json_data[file_path]
cache_json_file=file_path
if os.path.exists(file_path):
......@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
self.cache_json_data[file_path] = configs_dict
return configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
......
......@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1233,21 +1233,62 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
query_nope,
q,
q_quant,
q_scale,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
else:
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode:
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
decode_q = q_quant[:num_decode_tokens]
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# todo: bmm support
decode_ql_nope = torch.bmm(q_scale, decode_q_nope, self.W_UK_T) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA else torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
......
......@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
flash_mla_with_kvcache_fp8_with_cat,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
......@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
if envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
o, _ = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q = q_scale,
descale_k = k_scale,
)
else:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
......
......@@ -1051,14 +1051,15 @@ class Scheduler(SchedulerInterface):
def schedule(self) -> SchedulerOutput:
if envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd()
else:
if self.connector is not None:
return self.schedule_default()
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0 :
return self.schedule_split_pd()
if self.use_mla:
if self.full_cuda_graph and self.num_spec_tokens > 0:
return self.schedule_split_pd()
else:
return self.schedule_default()
else:
return self.schedule_default()
return self.schedule_split_pd()
else:
return self.schedule_default()
def _update_after_schedule(
self,
......@@ -1101,13 +1102,14 @@ class Scheduler(SchedulerInterface):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[-num_tokens:]
token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
......@@ -1241,7 +1243,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
request.num_generated_token_ids = len(generated_token_ids)
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
......@@ -1253,7 +1255,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
......
......@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
set_forward_context, set_profilling)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -514,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
num_new_tokens = len(new_token_ids)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
new_token_ids)
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
......@@ -541,8 +539,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
if not is_last_rank:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state)
req_index = self.input_batch.req_id_to_index.get(req_id)
else:
req_ids_to_add.append(req_id)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
......@@ -554,13 +557,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
if len(new_token_ids) > 0:
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
if spec_token_ids:
......@@ -1598,6 +1602,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -1675,11 +1684,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
attn_metadata,
)
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......@@ -2089,7 +2096,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
......@@ -3460,6 +3467,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if req_state.output_token_ids:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
......@@ -3481,7 +3493,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
#-----------------------------------
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment