Unverified Commit 60e2fdcf authored by Yineng Zhang, committed by GitHub

use sgl-kernel moe_align_block_size (#2581)


Co-authored-by: ispobock <ispobaoke@163.com>
Co-authored-by: HandH1998 <1335248067@qq.com>
parent d7c0e872
@@ -21,9 +21,9 @@
 runtime_common = ["aiohttp", "decord", "fastapi",
                   "orjson", "outlines>=0.0.44,<0.1.0",
                   "packaging", "pillow", "prometheus-client>=0.20.0",
                   "psutil", "pydantic", "python-multipart",
-                  "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
+                  "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
                   "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel"]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
......
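Not part of the commit: a small, hypothetical sanity check that the sgl-kernel dependency added to the srt extra above is importable and exposes the kernel the diff below starts calling. The check itself is an assumption for reviewers building locally, not anything shipped by this change.

import importlib

# Verify the new dependency resolves and provides the expected entry point.
mod = importlib.import_module("sgl_kernel")
assert hasattr(mod, "moe_align_block_size"), "sgl-kernel build is missing moe_align_block_size"
print("sgl_kernel provides moe_align_block_size:", mod.moe_align_block_size)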
@@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
@@ -266,9 +267,25 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    # FIXME(zhyncs)
+    if num_experts >= 256:
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad
......
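For context on the dispatch above: moe_align_block_size groups the flattened token-to-expert assignments by expert and pads each expert's group to a multiple of block_size, returning the sorted token indices, a per-block expert id, and the padded token count; the hunk only picks the sgl-kernel implementation when num_experts >= 256 and falls back to the vLLM op otherwise. The pure-PyTorch reference below is an illustrative sketch of those output semantics on CPU only; the name moe_align_block_size_ref and the sentinel padding value are assumptions, not the sgl-kernel or vLLM API.

import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Group flattened token->expert assignments by expert and pad each
    expert's group of token indices up to a multiple of block_size."""
    flat = topk_ids.flatten()
    pad_value = flat.numel()  # sentinel index used to fill padding slots (assumed)
    sorted_ids, expert_ids, num_tokens_post_pad = [], [], 0
    for e in range(num_experts):
        # Token-slot indices routed to expert e, in flattened order.
        idx = torch.nonzero(flat == e, as_tuple=False).flatten().tolist()
        padded_len = -(-len(idx) // block_size) * block_size  # round up to block_size
        idx += [pad_value] * (padded_len - len(idx))
        sorted_ids.extend(idx)
        expert_ids.extend([e] * (padded_len // block_size))  # one expert id per block
        num_tokens_post_pad += padded_len
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([num_tokens_post_pad], dtype=torch.int32),
    )


if __name__ == "__main__":
    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens, top-2 routing
    print(moe_align_block_size_ref(topk_ids, block_size=4, num_experts=3))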
@@ -95,6 +95,12 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+        # FIXME(HandH1998)
+        if (
+            "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+            and not self.server_args.disable_cuda_graph
+        ):
+            self.server_args.disable_cuda_graph = True

         if self.server_args.enable_double_sparsity:
             logger.info(
......
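As background for the ModelRunner change above: the guard keys off the architectures field of the Hugging Face config that SGLang wraps as self.model_config.hf_config, and force-disables CUDA graph whenever a DeepSeek-V3 checkpoint is loaded. A minimal sketch of that same check using transformers directly is below; the model id is an assumed example and the print stands in for the server-args mutation in the diff.

from transformers import AutoConfig

# Assumed example checkpoint; DeepSeek-V3 configs list "DeepseekV3ForCausalLM"
# in the "architectures" field of config.json.
hf_config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
if "DeepseekV3ForCausalLM" in (getattr(hf_config, "architectures", None) or []):
    # Mirrors the diff: CUDA graph is turned off for this architecture for now.
    print("DeepSeek-V3 detected; CUDA graph would be disabled (see FIXME in the diff).")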