Unverified Commit e179e0b7 authored by Cheng Wan, committed by GitHub

update sgl-kernel for EP: python part (#8550)

parent d9049592
@@ -54,7 +54,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.2.7",
+    "sgl-kernel==0.2.8",
     "torch==2.7.1",
     "torchaudio==2.7.1",
     "torchvision==0.22.1",
@@ -648,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.7",
+            "0.2.8",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
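For context, the bumped pin is enforced at startup by assert_pkg_version, shown in the hunk above. Below is a minimal sketch of that kind of runtime check, assuming importlib.metadata and packaging are available; _check_min_version is a hypothetical name, not sglang's actual helper.

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def _check_min_version(pkg: str, min_version: str, hint: str) -> None:
    # Hypothetical stand-in for sglang's assert_pkg_version: fail fast when the
    # installed wheel is missing or older than the pinned version.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_version}. {hint}")


# Mirrors the call in the diff above:
# _check_min_version(
#     "sgl-kernel",
#     "0.2.8",
#     "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
# )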
@@ -568,7 +568,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
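The new bound reserves padding room for one extra bucket: with expert parallelism, tokens whose expert is filtered out on this rank are grouped under id -1, so there are num_experts + 1 buckets and each can gain at most block_size - 1 padding slots when rounded up to a block boundary. A small worked example with made-up numbers (not taken from the diff):

block_size = 16
num_experts = 4
# num_experts + 1 buckets: 4 local experts plus the -1 bucket for filtered tokens
tokens_per_bucket = [5, 17, 1, 9, 3]


def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m


padded_total = sum(round_up(t, block_size) for t in tokens_per_bucket)  # 96
upper_bound = sum(tokens_per_bucket) + (num_experts + 1) * (block_size - 1)  # 35 + 75 = 110
assert padded_total <= upper_bound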
@@ -578,13 +578,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )

     # Threshold based on benchmark results
@@ -594,12 +590,11 @@ def moe_align_block_size(
         sgl_moe_align_block_size(
             topk_ids,
-            num_experts,
+            num_experts + 1,
             block_size,
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
-            token_cnts_buffer,
             cumsum_buffer,
             fuse_sorted_ids_padding,
         )
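One plausible reading of the new buffer size, based on the comment added in the diff: with num_experts + 1 ids (the local experts plus the -1 bucket), a cumulative-count buffer with a leading zero needs num_experts + 2 entries. The sketch below illustrates that layout assumption only; it is not the actual sgl-kernel 0.2.8 implementation of sgl_moe_align_block_size.

import torch

num_experts = 4
# One entry per id: num_experts local experts plus the -1 bucket (illustrative counts).
blocks_per_id = torch.tensor([1, 2, 1, 1, 2], dtype=torch.int32)
cumsum_buffer = torch.zeros(num_experts + 2, dtype=torch.int32)
cumsum_buffer[1:] = torch.cumsum(blocks_per_id, dim=0)
# cumsum_buffer[i] marks the first block owned by id i; cumsum_buffer[-1] is the padded total.
print(cumsum_buffer)  # tensor([0, 1, 3, 4, 5, 7], dtype=torch.int32)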