"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "27ca23dc002e06eade014ac6b801dc2dcbea40f3"
Commit ffd123f6 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix: 修复pp资源抢占bug,修复重复判断逻辑

fuse_moe_fp8接入marlin算子
fix(v1):修复抢占恢复时 BlockTable 溢出
feat(moe):新增 VLLM_USE_MOE_W16A16_TRTION 强制 Triton MoE
fix: 解决原版0消耗chunk-prefill崩溃问题
fp8增加fused_moe_gate参数
parent ed2c06c3
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt4.' + sha[:7] version = 'das.opt5.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt4' version = 'das.opt5'
# dtk version # dtk version
......
...@@ -208,6 +208,7 @@ if TYPE_CHECKING: ...@@ -208,6 +208,7 @@ if TYPE_CHECKING:
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0 VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
...@@ -1346,6 +1347,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1346,6 +1347,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Only capture when num_tokens < N (0 disables). # Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT": "VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")), lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
# vLLM will use deepgemm kernel for deepep ht mode # vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
......
...@@ -1726,6 +1726,11 @@ def fused_experts_impl( ...@@ -1726,6 +1726,11 @@ def fused_experts_impl(
or getattr(w2, "marlin_w16a16_packed", False) or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)) or _is_marlin_w16a16_packed(w1, w2))
if is_packed: if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try: try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin) fused_experts_impl_w16a16_marlin)
......
...@@ -101,8 +101,6 @@ def _is_marlin_w16a16_moe_supported( ...@@ -101,8 +101,6 @@ def _is_marlin_w16a16_moe_supported(
return False return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0: if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False return False
if not envs.VLLM_USE_LIGHTOP:
return False
try: try:
from lightop import get_moe_cuda_marlin_config_w16a16 from lightop import get_moe_cuda_marlin_config_w16a16
...@@ -1048,7 +1046,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1048,7 +1046,9 @@ class FusedMoE(torch.nn.Module):
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
moe_in_dtype = model_dtype moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = ( self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu" not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and self.activation == "silu"
and not self.apply_router_weight_on_input and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported( and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts, E=self.local_num_experts,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from typing import Callable, Optional from typing import Callable, Optional
from vllm import envs
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
...@@ -61,6 +62,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -61,6 +62,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose. # If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t()
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
......
...@@ -858,6 +858,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -858,6 +858,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_fused_gate: Optional[bool] = False,
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
...@@ -886,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -886,6 +887,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
use_fused_gate=use_fused_gate,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
......
...@@ -11,9 +11,11 @@ from vllm.config import CompilationLevel, get_current_vllm_config ...@@ -11,9 +11,11 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
try: try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8 from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except Exception: except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n") print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
...@@ -257,6 +259,41 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -257,6 +259,41 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
return output return output
def hipblaslt_w8a8_channelwise_scaled_mm(
qinput: torch.Tensor,
input_2d: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs
) -> torch.Tensor:
assert qinput.is_contiguous() and weight.is_contiguous()
assert qinput.shape[-1] == weight.shape[-1]
assert qinput.dtype == weight.dtype
m = qinput.shape[0]
k = qinput.shape[1]
n = weight.shape[0]
success, output = quant_ops.hipblaslt_w8a8_channelwise_gemm(
a = qinput,
b = weight,
scale_a = scale_a,
scale_b = scale_b,
m = m,
n = n,
k = k,
transpose_flag = "NT",
out_dtype = out_dtype,
bias = bias,
)
return output.view(m, n)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
...@@ -316,6 +353,8 @@ def dispatch_w8a8_scaled_mm( ...@@ -316,6 +353,8 @@ def dispatch_w8a8_scaled_mm(
if current_platform.is_rocm(): if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm
if envs.VLLM_W8A8_BACKEND == 3:
return hipblaslt_w8a8_channelwise_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only # torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token # so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights if (use_per_token_if_dynamic and not per_tensor_weights
......
...@@ -281,20 +281,26 @@ class Scheduler(SchedulerInterface): ...@@ -281,20 +281,26 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens) num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
if self.use_pp:
preemptable_reqs = [r for r in self.running if
r.num_tokens_with_spec != r.num_computed_tokens]
else:
preemptable_reqs = self.running
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY: if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max( preempted_req = max(
self.running, preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time), key=lambda r: (r.priority, r.arrival_time),
) )
self.running.remove(preempted_req)
else: else:
preempted_req = self.running.pop() preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req)
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -901,20 +907,26 @@ class Scheduler(SchedulerInterface): ...@@ -901,20 +907,26 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens) num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
if self.use_pp:
preemptable_reqs = [r for r in self.running if
r.num_tokens_with_spec != r.num_computed_tokens]
else:
preemptable_reqs = self.running
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY: if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max( preempted_req = max(
self.running, preemptable_reqs,
key=lambda r: (r.priority, r.arrival_time), key=lambda r: (r.priority, r.arrival_time),
) )
self.running.remove(preempted_req)
else: else:
preempted_req = self.running.pop() preempted_req = preemptable_reqs[-1]
self.running.remove(preempted_req)
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.spec_token_ids = []
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......
...@@ -545,7 +545,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -545,7 +545,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
......
...@@ -796,6 +796,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -796,6 +796,7 @@ class V1ZeroModelRunner(GPUModelRunner):
req_state = self.requests[req_id] req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx] token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1: if token_idx == -1:
self.fix_sampled_token_ids[req_idx].clear()
continue continue
fix_len = len(self.fix_sampled_token_ids[req_idx]) fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx] req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment