Commit 76572db3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 864c718a f3e13c54
...@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( if self.kv_cache_dtype != "fp8":
"FlashMLA with FP8 KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
...@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
...@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits, num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
) )
return self._v_up_proj(o) return self._v_up_proj(o)
...@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output return output
\ No newline at end of file
...@@ -75,6 +75,8 @@ def flash_mla_with_kvcache( ...@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
k_scale = None,
kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
...@@ -97,6 +99,22 @@ def flash_mla_with_kvcache( ...@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
)
return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
......
...@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator ...@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.forward_context import get_profilling
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -169,7 +170,7 @@ def _support_torch_compile( ...@@ -169,7 +170,7 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling(): if self.do_not_compile or torch.compiler.is_compiling() or get_profilling():
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# the first compilation needs to have dynamic shapes marked # the first compilation needs to have dynamic shapes marked
......
...@@ -1087,7 +1087,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1087,7 +1087,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "True").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
} }
...@@ -1162,4 +1162,4 @@ def compute_hash() -> str: ...@@ -1162,4 +1162,4 @@ def compute_hash() -> str:
hash_str = hashlib.md5(str(factors).encode(), hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest() usedforsecurity=False).hexdigest()
return hash_str return hash_str
\ No newline at end of file
...@@ -196,3 +196,16 @@ def set_forward_context( ...@@ -196,3 +196,16 @@ def set_forward_context(
_forward_context = prev_context _forward_context = prev_context
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context) set_tbo_forward_context(_forward_context)
# Process-wide flag that is True while a profiling run is in progress.
# It is only read through get_profilling() and written through
# set_profilling().
_profiling: bool = False


def set_profilling(profiling: bool) -> None:
    """Set the module-level profiling flag.

    This is a plain setter: call ``set_profilling(True)`` before a
    profiling run and ``set_profilling(False)`` once it is finished.
    It is intentionally NOT a context manager; the function never yields,
    so decorating it with ``contextlib.contextmanager`` would make
    ``with set_profilling(...)`` fail on enter.

    The (misspelled) name is kept unchanged for existing callers.

    Args:
        profiling: New value of the flag.
    """
    global _profiling
    _profiling = profiling


def get_profilling() -> bool:
    """Return the current value of the module-level profiling flag."""
    # Reading a module global needs no ``global`` declaration.
    return _profiling
\ No newline at end of file
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add) marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
def get_scalar_type(num_bits: int, has_zp: bool): def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp: if has_zp:
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
...@@ -104,8 +104,8 @@ def fused_marlin_moe( ...@@ -104,8 +104,8 @@ def fused_marlin_moe(
topk = topk_ids.shape[1] # 8 topk = topk_ids.shape[1] # 8
#暂时固定为16384 #暂时固定为16384
CHUNK_SIZE = 16384 #CHUNK_SIZE = 16384
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
if workspace is None: if workspace is None:
...@@ -120,18 +120,21 @@ def fused_marlin_moe( ...@@ -120,18 +120,21 @@ def fused_marlin_moe(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] if envs.VLLM_USE_GLOBAL_CACHE13:
intermediate_cache13 = get_moe_cache(topk, N, K, device=hidden_states.device, dtype=hidden_states.dtype)
else:
intermediate_cache13 = torch.empty(
(M * topk * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] intermediate_cache3 = intermediate_cache13[:M * topk * K]
intermediate_cache3 = intermediate_cache3.view(-1, K) intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \ use_atomic_add = hidden_states.dtype == torch.half or \
......
...@@ -58,6 +58,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -58,6 +58,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, self.eh_proj = nn.Linear(config.hidden_size * 2,
...@@ -75,6 +80,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -75,6 +80,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0, spec_step_index: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP # masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0 inputs_embeds[positions == 0] = 0
...@@ -111,10 +118,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -111,10 +118,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for idx in range(self.mtp_start_layer_idx, for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers) self.mtp_start_layer_idx + self.num_mtp_layers)
}) })
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def forward( def forward(
...@@ -125,8 +129,6 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -125,8 +129,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers) current_step_idx = (spec_step_idx % self.num_mtp_layers)
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
...@@ -308,25 +310,353 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -308,25 +310,353 @@ class DeepSeekMTP(nn.Module, SupportsPP):
""" """
Rewrite the weight name to match the format of the original model. Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
""" """
spec_layer_weight_names = [ spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
] ]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names: for weight_name in spec_layer_weight_names:
if weight_name in name: if weight_name in name:
spec_layer_weight = True spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break break
if not spec_layer_weight: if not spec_layer_weight:
# treat rest weights as weights for transformer layer block # treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.", name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.") f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name return name
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
...@@ -647,10 +647,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -647,10 +647,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True).contiguous() block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...], if envs.VLLM_ZERO_OVERHEAD:
repeats, dim=0).contiguous() decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]),
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous() device=block_table_tensor.device)
arange_np = np.arange(self._num_decodes)
indices_np = np.repeat(arange_np, query_lens)
indices = torch.from_numpy(indices_np).pin_memory().to(
block_table_tensor.device, non_blocking=True)
decode_block_table_tensor = block_table_tensor[indices].contiguous()
decode_seq_lens = seq_lens[indices].contiguous()
else:
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True).contiguous() seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
...@@ -1086,6 +1098,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1086,6 +1098,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output_padded return output_padded
\ No newline at end of file
...@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( if self.kv_cache_dtype != "fp8":
"FlashMLA V1 with FP8 KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
...@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
...@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits, num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
) )
return self._v_up_proj(o) return self._v_up_proj(o)
...@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import ( ...@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank, get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context, from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context) set_forward_context, set_profilling)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.zero_overhead.v1.gpu_model_runner import execute_model_sampled, zero_prepare_inputs
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
...@@ -955,15 +954,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -955,15 +954,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += arange
# TODO: Optimize the CPU -> GPU copy. if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device, logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking=True) non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to( target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
else:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids. # Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208] # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...@@ -1364,8 +1373,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1364,8 +1373,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ZERO_OVERHEAD:
zero_prepare_inputs(self, scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph: if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
...@@ -1507,21 +1514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1507,21 +1514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
if envs.VLLM_ZERO_OVERHEAD:
return execute_model_sampled(self, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits)
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
...@@ -2095,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2095,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens, attn_metadata) self.drafter.dummy_run(num_tokens, attn_metadata)
...@@ -2230,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2230,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return pooler_output return pooler_output
def profile_run(self) -> None: def profile_run(self) -> None:
# set profiling flag to avoid torch compile
set_profilling(True)
self._sync_device()
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them. # TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
...@@ -2313,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2313,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
del hidden_states, output del hidden_states, output
self.encoder_cache.clear() self.encoder_cache.clear()
gc.collect() gc.collect()
set_profilling(False)
def capture_model(self) -> None: def capture_model(self) -> None:
if not self.use_cuda_graph: if not self.use_cuda_graph:
......
...@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats ...@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -163,8 +164,13 @@ class Worker(WorkerBase): ...@@ -163,8 +164,13 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Construct the model runner # Construct the model runner
self.model_runner: GPUModelRunner = GPUModelRunner( if envs.VLLM_ZERO_OVERHEAD:
self.vllm_config, self.device) logger.info('use zero overhead model_runner')
self.model_runner: GPUModelRunner = V1ZeroModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
if self.rank == 0: if self.rank == 0:
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
......
...@@ -14,11 +14,15 @@ requsets_valid_token_len = {} ...@@ -14,11 +14,15 @@ requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
max_model_len: int, max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool: pooler_output: Optional[torch.Tensor] = None,
if request.request_id not in requsets_valid_token_len: use_valid_token_len:bool = False) -> bool:
requsets_valid_token_len[request.request_id] = 0 if use_valid_token_len:
return False if request.request_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[request.request_id] requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len valid_num_tokens = request.num_prompt_tokens + valid_output_len
if (valid_num_tokens >= max_model_len if (valid_num_tokens >= max_model_len
or valid_output_len >= request.max_tokens): or valid_output_len >= request.max_tokens):
...@@ -62,110 +66,121 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -62,110 +66,121 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
# fix last model out in zero overhead # fix last model out in zero overhead
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids): if model_runner_output.fix_req_ids is not None:
if req_id not in scheduler.requests: for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
continue if req_id not in scheduler.requests:
request = scheduler.requests[req_id] continue
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx] request = scheduler.requests[req_id]
if req_id not in requsets_valid_token_len: generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
requsets_valid_token_len[req_id] = 0 if req_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[req_id] requsets_valid_token_len[req_id] = 0
fix_offset = valid_output_len - request.num_output_tokens valid_output_len = requsets_valid_token_len[req_id]
if isinstance(generated_token_ids, int): fix_offset = valid_output_len - request.num_output_tokens
request._output_token_ids[fix_offset] = generated_token_ids if isinstance(generated_token_ids, int):
request._all_token_ids[fix_offset] = generated_token_ids request._output_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1 request._all_token_ids[fix_offset] = generated_token_ids
else: requsets_valid_token_len[req_id] += 1
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else: else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids if valid_output_end == 0:
requsets_valid_token_len[req_id] += len(generated_token_ids) request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
stopped = False # Check for stop and update request state.
new_logprobs = None # This must be called before we make the EngineCoreOutput.
new_token_ids = generated_token_ids for num_new, output_token_id in enumerate(new_token_ids, 1):
kv_transfer_params = None stopped = check_stop(request, scheduler.max_model_len, True)
if stopped:
# Check for stop and update request state. kv_transfer_params = scheduler._free_request(request)
# This must be called before we make the EngineCoreOutput. del new_token_ids[num_new:] # Trim new tokens if needed.
for num_new, output_token_id in enumerate(new_token_ids, 1): break
stopped = check_stop(request, scheduler.max_model_len)
if stopped: pooler_output = None
kv_transfer_params = scheduler._free_request(request) if pooler_outputs:
del new_token_ids[num_new:] # Trim new tokens if needed. pooler_output = pooler_outputs[req_index]
break stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
pooler_output = None if stopped:
if pooler_outputs: kv_transfer_params = scheduler._free_request(request)
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len, # Extract sample logprobs if needed.
pooler_output) if request.sampling_params is not None \
if stopped: and request.sampling_params.logprobs is not None and logprobs:
kv_transfer_params = scheduler._free_request(request) # NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# Extract sample logprobs if needed. new_logprobs = logprobs.slice(req_index, req_index + 1)
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs: if new_token_ids and scheduler.structured_output_manager.should_advance(
# NOTE: once we support N tokens per step (spec decode), request):
# the outer lists can be of length > 1. # NOTE: structured_output_request
new_logprobs = logprobs.slice(req_index, req_index + 1) # should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
if new_token_ids and scheduler.structured_output_manager.should_advance( request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request): req_id, new_token_ids)
# NOTE: structured_output_request
# should not be None if use_structured_output, we have # spec_token_ids comes from the model runner output
# check above, so safe to ignore type warning if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.num_nans_in_logits = num_nans_in_logits[req_id]
req_id, new_token_ids)
# Get prompt logprobs for this request.
# spec_token_ids comes from the model runner output prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if new_token_ids or pooler_output is not None \
request.num_nans_in_logits = num_nans_in_logits[req_id] or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else: else:
request.spec_token_ids = spec_token_ids[req_index] # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# fix last model out in zero overhead
if model_runner_output.fix_draft_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
if req_id not in scheduler.requests:
continue
request = scheduler.requests[req_id]
# Add newly generated spec token ids to the request.
if model_runner_output.fix_draft_tokens_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
model_runner_output.fix_draft_tokens_ids[req_idx])
else:
request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop. # expensive operations inside the loop.
for request in scheduler.running: for request in scheduler.running:
if request.is_finished():
if req_id in requsets_valid_token_len:
requsets_valid_token_len.pop(req_id)
continue
req_id = request.request_id req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0: if num_tokens_scheduled == 0:
...@@ -212,19 +227,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -212,19 +227,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, scheduler.max_model_len)
# if stopped: if model_runner_output.is_output_valid:
# kv_transfer_params = scheduler._free_request(request) stopped = check_stop(request, scheduler.max_model_len,
# del new_token_ids[num_new:] # Trim new tokens if needed. False)
# break if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None pooler_output = None
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_index] if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len, pooler_output = pooler_outputs[req_index]
pooler_output) stopped = check_stop(request, scheduler.max_model_len,
# if stopped: pooler_output,
# kv_transfer_params = scheduler._free_request(request) False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params is not None \ if request.sampling_params is not None \
...@@ -255,7 +275,30 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -255,7 +275,30 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else: else:
request.spec_token_ids = spec_token_ids[req_index] request.spec_token_ids = spec_token_ids[req_index]
if not stopped: if model_runner_output.is_output_valid:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
if stopped:
if req_id in requsets_valid_token_len:
requsets_valid_token_len.pop(req_id)
else:
new_running.append(request) new_running.append(request)
scheduler.running = new_running scheduler.running = new_running
......
This diff is collapsed.
...@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class ZeroV1ModelRunnerOutput(ModelRunnerOutput): class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs] # [num_reqs]
fix_req_ids: list[str] = None fix_req_ids: list[str] = None
fix_sampled_token_ids:list[list[int]] = None fix_sampled_token_ids:list[list[int]] = None
\ No newline at end of file fix_draft_req_ids:list[list[int]] = None
fix_draft_tokens_ids:list[list[int]] = None
is_output_valid:bool = True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment