Commit 22c6c645 authored by zhuwenwen
Browse files

Fix MoE params and runtime error

parent de5774fa
......@@ -27,7 +27,8 @@ if current_platform.is_cuda():
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True
if current_platform.is_rocm():
import flash_mla_cuda
......@@ -42,7 +43,7 @@ def _is_flashmla_available() -> tuple[bool, str | None]:
"compiled due to insufficient nvcc version or a supported arch "
"was not in the list of target arches to compile for.",
)
if not _flashmla_extension_C_AVAILABLE:
if not _flashmla_extension_C_AVAILABLE or not current_platform.is_rocm():
return (
False,
"vllm._flashmla_extension_C is not available, likely "
......
......@@ -1747,7 +1747,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
(os.environ.get("VLLM_USE_TOPK_RENORM", "False").lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
......
......@@ -670,7 +670,7 @@ def fused_moe_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
# SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
......@@ -1185,7 +1185,7 @@ def zero_experts_compute_triton(
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None, use_nn_moe: bool | None = False,
) -> str:
# device_name = current_platform.get_device_name().replace(" ", "_")
# # Set device_name to H200 if a device from the H200 family is detected
......
......@@ -289,6 +289,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
......@@ -316,6 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
use_nn_moe=use_nn_moe,
use_fused_gate=use_fused_gate,
)
def get_fused_moe_quant_config(
......@@ -351,10 +355,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_fused_gate=use_fused_gate,
)
if self.rocm_aiter_moe_enabled:
......@@ -391,6 +398,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
......
......@@ -104,6 +104,7 @@ if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__)
......@@ -1540,7 +1541,7 @@ class DeepseekV2ForCausalLM(
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.topk = self.config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment