Commit 42b06117 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' into v0.9.2-dev-ds

parents b2d14ba3 48114bb1
...@@ -179,7 +179,6 @@ void merge_attn_states(torch::Tensor& output, ...@@ -179,7 +179,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes( void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...@@ -204,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead( ...@@ -204,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead(
torch::Tensor vertical_indices_count, // [N_HEADS, ] torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size, torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal); int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
......
...@@ -229,8 +229,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -229,8 +229,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"); " Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def( ops.def(
"convert_vertical_slash_indexes(" "convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, " " Tensor! block_count, Tensor! block_offset, "
...@@ -253,7 +251,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -253,7 +251,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" bool causal) -> ()"); " bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead); &convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
......
...@@ -55,7 +55,12 @@ class ReqMeta: ...@@ -55,7 +55,12 @@ class ReqMeta:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
self.parallel_config = vllm_config.parallel_config
self.model_config = vllm_config.model_config
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
self.pp_size = self.parallel_config.pipeline_parallel_size
@dataclass @dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata): class P2pNcclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] requests: list[ReqMeta]
...@@ -285,8 +290,29 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -285,8 +290,29 @@ class P2pNcclConnector(KVConnectorBase_V1):
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4))
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4))
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i))
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
......
...@@ -164,11 +164,10 @@ if TYPE_CHECKING: ...@@ -164,11 +164,10 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1095,15 +1094,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1095,15 +1094,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHT_OP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use opt cat for deepseek-v3
"VLLM_USE_TRITON_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_aatn_states,not triton # vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
......
...@@ -44,12 +44,15 @@ from lightop import op ...@@ -44,12 +44,15 @@ from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0') os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype): def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton global moe_cache_singleton
if moe_cache_singleton is None: if moe_cache_singleton is None:
...@@ -1295,7 +1298,9 @@ def inplace_fused_experts_fake( ...@@ -1295,7 +1298,9 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
pass pass
...@@ -1365,7 +1370,9 @@ def outplace_fused_experts_fake( ...@@ -1365,7 +1370,9 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1558,7 +1565,9 @@ def fused_experts_impl( ...@@ -1558,7 +1565,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
...@@ -1587,7 +1596,7 @@ def fused_experts_impl( ...@@ -1587,7 +1596,7 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe= False, use_nn_moe= False,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor
) )
# #
...@@ -1760,30 +1769,28 @@ def fused_experts_impl( ...@@ -1760,30 +1769,28 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
if envs.VLLM_USE_LIGHT_OP and not dpsk_fp16_quick: from lightop import op as op
if shared_output is not None: op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], None, routed_scaling_factor)
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor) # else:
else: # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), # out_hidden_states[begin_chunk_idx:end_chunk_idx])
out_hidden_states[begin_chunk_idx:end_chunk_idx]) # if shared_output is not None:
if shared_output is not None: # if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick: # out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx] # else:
else: # # Fix FP16 overflow
# Fix FP16 overflow # # See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details. # out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor) # else:
# Deepseek theoretically wouldn't happen # if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
else: # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick: # out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else: else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1817,7 +1824,7 @@ def fused_moe( ...@@ -1817,7 +1824,7 @@ def fused_moe(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
......
...@@ -42,7 +42,7 @@ from vllm.platforms.interface import CpuArchEnum ...@@ -42,7 +42,7 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from lightop import op
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -223,6 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -223,6 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -375,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -375,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -400,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -400,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
shared_output=shared_output,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate) use_fused_gate=use_fused_gate)
...@@ -422,6 +425,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -422,6 +425,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -466,7 +470,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -466,7 +470,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor
) )
def forward_cpu( def forward_cpu(
...@@ -1284,7 +1288,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1284,7 +1288,8 @@ class FusedMoE(torch.nn.Module):
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if use_fused_gate: if use_fused_gate:
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
topk_weights, topk_ids = op.moe_fused_gate( topk_weights, topk_ids = op.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
...@@ -1434,14 +1439,14 @@ class FusedMoE(torch.nn.Module): ...@@ -1434,14 +1439,14 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: torch.Tensor): shared_output: Optional[torch.Tensor] = None):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, shared_output, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1521,7 +1526,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1521,7 +1526,7 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: torch.Tensor): shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
...@@ -1556,6 +1561,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1556,6 +1561,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view, expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map, logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate use_fused_gate=self.use_fused_gate
...@@ -1628,8 +1634,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1628,8 +1634,8 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
...@@ -1637,8 +1643,8 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared ...@@ -1637,8 +1643,8 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared
return self.forward_impl(hidden_states, router_logits, shared_output) return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -9,7 +9,6 @@ from vllm.triton_utils import tl, triton ...@@ -9,7 +9,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up from vllm.utils import cdiv, round_up
import vllm.envs as envs import vllm.envs as envs
from lightop import op
@triton.jit @triton.jit
...@@ -152,7 +151,8 @@ def moe_align_block_size( ...@@ -152,7 +151,8 @@ def moe_align_block_size(
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False, pad_sorted_ids: bool = False,
num_token: Optional[int] = None num_token: Optional[int] = None,
expert_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
...@@ -232,13 +232,17 @@ def moe_align_block_size( ...@@ -232,13 +232,17 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP or expert_mask is not None:
from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None) expert_ids, num_tokens_post_pad,
expert_map = expert_map,
expert_mask = expert_mask,
num_local_tokens = None)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad) expert_ids, num_tokens_post_pad)
if expert_map is not None: if expert_map is not None:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad return sorted_ids, expert_ids, num_tokens_post_pad
\ No newline at end of file
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool: def is_rocm_aiter_rmsnorm_enabled() -> bool:
...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
return out return out
def rms_norm_opt(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
from lightop import fused_rms_norm_contiguous
out = torch.empty_like(x)
fused_rms_norm_contiguous(
out,
x,
weight,
variance_epsilon,
)
return out
def rms_norm_opt_fake(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return torch.empty_like(x)
direct_register_custom_op(
op_name="rms_norm_opt",
op_func=rms_norm_opt,
mutates_args=[],
fake_impl=rms_norm_opt_fake,
)
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp): ...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp):
else: else:
return norm_func(x, self.weight.data, self.variance_epsilon) return norm_func(x, self.weight.data, self.variance_epsilon)
def forward_cuda_opt(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return norm_func(x, residual, self.weight.data,
self.variance_epsilon)
else:
return torch.ops.vllm.rms_norm_opt(x, self.weight.data, self.variance_epsilon)
def forward_apex( def forward_apex(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None): bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None):
""" """
Use the output of create_weights and the CompressedTensorsScheme Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the associated with the layer to apply the forward pass with the
...@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = layer.scheme scheme = layer.scheme
if scheme is None: if scheme is None:
raise ValueError("A scheme must be defined for each layer") raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias) return scheme.apply_weights(layer, x, bias=bias, input_quant_args=input_quant_args)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
......
...@@ -1097,7 +1097,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1097,7 +1097,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -1137,7 +1137,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1137,7 +1137,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False) use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
...@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer) self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias) # return self.kernel.apply_weights(layer, x, bias)
...@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point=layer.input_zero_point, input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj, azp_adj=layer.azp_adj,
bias=bias, bias=bias,
w8a8_strategy=self.w8a8_strategy) w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args)
\ No newline at end of file
...@@ -358,6 +358,7 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -358,6 +358,7 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -398,4 +399,6 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -398,4 +399,6 @@ class SlimQuantW4A8Int8MoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -392,20 +392,26 @@ def apply_int8_linear( ...@@ -392,20 +392,26 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None, azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0, w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None
): ):
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else: if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input, assert len(input_quant_args) == 2
input_scale, x_zp =None
input_zero_point, x_q, x_scale = input_quant_args
symmetric=symmetric) else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None: if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token # Currently, static is always per-tensor and dynamic is per-token
......
...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig ...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=[],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -938,8 +958,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -938,8 +958,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
num_warps=1) num_warps=1)
call(query) # if envs.VLLM_USE_LIGHTOP:
call(key) if False:
torch.ops.vllm.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key return query, key
else: else:
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
......
...@@ -238,14 +238,28 @@ def get_model_architecture( ...@@ -238,14 +238,28 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0': if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# awq相关配置 # awq相关配置
try: try:
if os.getenv('AWQ_MOE_SZ') == None: if os.getenv('AWQ_MOE_SZ') == None:
......
...@@ -213,10 +213,30 @@ class DeepseekV2MoE(nn.Module): ...@@ -213,10 +213,30 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts( if envs.VLLM_USE_LIGHTOP and not self.dpsk_fp16_quick:
hidden_states=hidden_states, final_hidden_states = self.experts(
router_logits=router_logits, hidden_states=hidden_states,
shared_output=shared_output) router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
...@@ -546,7 +566,10 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -546,7 +566,10 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
......
...@@ -216,7 +216,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -216,7 +216,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -928,8 +927,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -928,8 +927,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
...@@ -993,8 +993,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -993,8 +993,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
......
...@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm import envs from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -167,8 +167,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -167,8 +167,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1) .unsqueeze(1)
else: else:
......
...@@ -5,7 +5,10 @@ from functools import reduce ...@@ -5,7 +5,10 @@ from functools import reduce
import pytest import pytest
import torch import torch
import math import math
from lightop import ds_cat import vllm.envs as envs
if envs.VLLM_USE_OPT_CAT:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim): def test_concat_Acc_prefill(shape_pair, dim):
......
...@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface): ...@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] - num_tokens = req.num_generated_token_ids
len(spec_decode_tokens.get(req_id, ())))
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req. token_ids = req.all_token_ids[-num_tokens:]
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
...@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface): ...@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens # num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
...@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface): ...@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens, where is given by: # tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids)) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment