Commit 22c6c645 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix moe params and run error

parent de5774fa
...@@ -27,7 +27,8 @@ if current_platform.is_cuda(): ...@@ -27,7 +27,8 @@ if current_platform.is_cuda():
except ImportError: except ImportError:
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = False
else: else:
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True
if current_platform.is_rocm(): if current_platform.is_rocm():
import flash_mla_cuda import flash_mla_cuda
...@@ -42,7 +43,7 @@ def _is_flashmla_available() -> tuple[bool, str | None]: ...@@ -42,7 +43,7 @@ def _is_flashmla_available() -> tuple[bool, str | None]:
"compiled due to insufficient nvcc version or a supported arch " "compiled due to insufficient nvcc version or a supported arch "
"was not in the list of target arches to compile for.", "was not in the list of target arches to compile for.",
) )
if not _flashmla_extension_C_AVAILABLE: if not _flashmla_extension_C_AVAILABLE or not current_platform.is_rocm():
return ( return (
False, False,
"vllm._flashmla_extension_C is not available, likely " "vllm._flashmla_extension_C is not available, likely "
......
...@@ -1747,7 +1747,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1747,7 +1747,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use optimized topk_softmax + renormalize # vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM": "VLLM_USE_TOPK_RENORM":
lambda: lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in (os.environ.get("VLLM_USE_TOPK_RENORM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use fused RMS + RoPE kernel # vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
......
...@@ -670,7 +670,7 @@ def fused_moe_kernel( ...@@ -670,7 +670,7 @@ def fused_moe_kernel(
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr, # SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
...@@ -1185,7 +1185,7 @@ def zero_experts_compute_triton( ...@@ -1185,7 +1185,7 @@ def zero_experts_compute_triton(
# Adapted from: https://github.com/sgl-project/sglang/pull/2628 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name( def get_config_file_name(
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None E: int, N: int, dtype: str | None, block_shape: list[int] | None = None, use_nn_moe: bool | None = False,
) -> str: ) -> str:
# device_name = current_platform.get_device_name().replace(" ", "_") # device_name = current_platform.get_device_name().replace(" ", "_")
# # Set device_name to H200 if a device from the H200 family is detected # # Set device_name to H200 if a device from the H200 family is detected
......
...@@ -116,6 +116,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -116,6 +116,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None, expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
...@@ -289,6 +289,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -289,6 +289,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: torch.Tensor | None = None, expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
...@@ -316,6 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -316,6 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count, logical_replica_count=logical_replica_count,
use_nn_moe=use_nn_moe,
use_fused_gate=use_fused_gate,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -351,10 +355,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -351,10 +355,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: torch.Tensor | None = None, expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_fused_gate=use_fused_gate,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
...@@ -391,6 +398,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -391,6 +398,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
......
...@@ -104,6 +104,7 @@ if current_platform.is_cuda_alike(): ...@@ -104,6 +104,7 @@ if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1540,7 +1541,7 @@ class DeepseekV2ForCausalLM( ...@@ -1540,7 +1541,7 @@ class DeepseekV2ForCausalLM(
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok self.tritonsingleton.topk = self.config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method self.tritonsingleton.quant_method=self.quant_method
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment