".github/vscode:/vscode.git/clone" did not exist on "7835904647d37c9eff25c2cea3801294a85c5cf2"
Commit 3d062a1c authored by zhuwenwen
Browse files

update fusedmoe

parent fe79f042
......@@ -4,7 +4,7 @@
import functools
import json
import os
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, List, Optional, Tuple
import torch
......@@ -31,7 +31,7 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
......@@ -502,7 +502,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int],] = None,
block_shape: Optional[List[int],] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
......@@ -534,7 +534,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.size(0),
A.size(0) * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.size(1) if not use_nn_moe else B.size[2], META['BLOCK_SIZE_N']), )
B.size(1) if not use_nn_moe else B.size(2), META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
......@@ -621,7 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1) if not use_nn_moe else B.size[2],
B.size(1) if not use_nn_moe else B.size(2),
B.size(1),
EM,
num_tokens,
......@@ -660,7 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[list[int]] = None, use_nn_moe: Optional[bool] = False) -> str:
block_shape: Optional[List[int]] = None, use_nn_moe: Optional[bool] = False) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
......@@ -783,7 +783,7 @@ def get_default_config(
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
......@@ -846,7 +846,7 @@ def try_get_optimal_moe_config(
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config
......@@ -893,9 +893,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
# if is_rocm_aiter_moe_enabled():
# from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
# return rocm_aiter_topk_softmax
return vllm_topk_softmax
......@@ -1033,7 +1033,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
......@@ -1064,7 +1064,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
pass
......@@ -1099,7 +1099,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
......@@ -1130,7 +1130,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1183,7 +1183,7 @@ def fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
......@@ -1268,7 +1268,7 @@ def fused_experts_impl(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
# Check constraints.
......@@ -1276,7 +1276,7 @@ def fused_experts_impl(
assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch")
elif use_nn_moe:
assert hidden_states.size[1] == w1.size[1], "Hidden size mismatch"
assert hidden_states.size(1) == w1.size(1), "Hidden size mismatch"
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
......@@ -1394,6 +1394,7 @@ def fused_experts_impl(
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......@@ -1432,6 +1433,7 @@ def fused_experts_impl(
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......@@ -1479,7 +1481,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
"""
......@@ -1525,7 +1527,7 @@ def fused_moe(
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[list[int]]): Optional block size for block-wise
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
......@@ -1577,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
):
super().__init__(
FusedMoEQuantConfig.make(
......@@ -1762,7 +1764,7 @@ def modular_triton_fused_moe(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: Optional[List[int]] = None,
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
......
......@@ -51,10 +51,10 @@ else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk)
elif current_platform.is_cpu():
# if is_rocm_aiter_moe_enabled():
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
# rocm_aiter_grouped_topk as grouped_topk)
if current_platform.is_cpu():
pass
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment