Unverified commit 64c87135, authored by Yineng Zhang and committed by GitHub

remove activation dependency in fused_moe (#3433)

parent 1646149a
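For readers unfamiliar with the kernels being dispatched in the diff below: silu_and_mul and gelu_and_mul consume a gate/up buffer of even width 2*d and produce a d-wide result (SiLU or GELU applied to the gate half, multiplied elementwise by the up half). The plain-PyTorch reference below is a sketch of that assumed semantics, not the kernel implementation; the real kernels write into a pre-allocated output buffer, and the exact GELU variant (erf vs. tanh approximation) is not pinned down here.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -- first half is the gate projection, second half the up projection.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Same layout, with GELU as the gating nonlinearity.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d]) * x[..., d:]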
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 is_hip_flag = is_hip()
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 @triton.jit
 def fused_moe_kernel(
@@ -989,9 +998,15 @@ def fused_experts_impl(
 )
 if activation == "silu":
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    if _is_cuda:
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 elif activation == "gelu":
-    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    if _is_cuda:
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 else:
     raise ValueError(f"Unsupported activation: {activation=}")
...
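One detail worth calling out in the hunk above: the two backends disagree on argument order. The sgl_kernel functions are called as kernel(input, output), while the ops.* fallback is called as kernel(output, input), so the dispatch has to reorder the buffers rather than just swap a module prefix. Below is a hedged, standalone sketch of that dispatch; fused_silu_and_mul and fallback_silu_and_mul are hypothetical stand-ins modeled in plain PyTorch so the snippet runs without sgl_kernel or the ops module installed.

import torch
import torch.nn.functional as F

def fused_silu_and_mul(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the sgl_kernel call style: (input, output), result written in place.
    d = out.shape[-1]
    out.copy_(F.silu(x[..., :d]) * x[..., d:])

def fallback_silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    # Stand-in for the ops fallback call style: (output, input), same in-place result.
    d = out.shape[-1]
    out.copy_(F.silu(x[..., :d]) * x[..., d:])

def apply_silu_and_mul(x: torch.Tensor, out: torch.Tensor, use_fused: bool) -> None:
    # Mirrors the branch in fused_experts_impl: fused kernel on CUDA, fallback elsewhere,
    # passing the buffers in the order each backend expects.
    if use_fused:
        fused_silu_and_mul(x, out)
    else:
        fallback_silu_and_mul(out, x)

# Tiny check that both paths agree on the result.
x = torch.randn(4, 16)
out_fused, out_fallback = torch.empty(4, 8), torch.empty(4, 8)
apply_silu_and_mul(x, out_fused, use_fused=True)
apply_silu_and_mul(x, out_fallback, use_fused=False)
assert torch.allclose(out_fused, out_fallback)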