Unverified commit 7722c11c authored by HAI, committed by GitHub

Regression fix to AMD/ROCm from recent change (#2606)

parent b2ed5c8e
@@ -11,12 +11,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True

 logger = logging.getLogger(__name__)

 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -268,7 +273,7 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     # FIXME(zhyncs)
-    if num_experts >= 256:
+    if not_hip and num_experts >= 256:
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
...
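For context, the net effect of the two hunks is the guard sketched below: on ROCm/HIP builds the CUDA-only `sgl_kernel` op is never imported, `not_hip` stays `False`, and `moe_align_block_size` keeps using its pre-existing fallback path. This is a minimal sketch assembled from the diff, not the full file; `is_hip()` is re-implemented via `torch.version.hip` only to keep the snippet self-contained, and the call site (`align`) plus the remaining kernel arguments elided by "..." above are hypothetical stand-ins.

```python
# Sketch of the guarded-import pattern introduced by this commit (not the full module).
import torch


def is_hip() -> bool:
    # Stand-in for sglang.srt.utils.is_hip(); ROCm builds of torch set torch.version.hip.
    return torch.version.hip is not None


not_hip = False
if not is_hip():
    # sgl_kernel's moe_align_block_size is a CUDA op; importing it unconditionally
    # on ROCm was the regression this commit fixes.
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

    not_hip = True


def align(topk_ids: torch.Tensor, num_experts: int) -> None:
    # Hypothetical call site mirroring the second hunk: the CUDA kernel is only
    # used off-ROCm and for large expert counts; otherwise execution falls through
    # to the existing path (elided in the diff).
    if not_hip and num_experts >= 256:
        sgl_moe_align_block_size(topk_ids, num_experts)  # real call passes more args (elided)
    else:
        pass  # pre-existing fallback path
```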