Unverified Commit 3d32e4a3 authored by xiaobochen's avatar xiaobochen Committed by GitHub
Browse files

Resubmit MoE-EP (#2371)

parent 64fceab8
......@@ -105,6 +105,12 @@ jobs:
cd test/srt
python3 test_update_weights_from_distributed.py
- name: Evaluate MoE EP accuracy (TP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 test_moe_ep.py
performance-test-1-gpu-part-1:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
......
import logging
from typing import Optional
import torch
import triton
import triton.language as tl
logger = logging.getLogger(__name__)
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
expert = tl.program_id(0)
low = 0
high = num_toks - 1
target_location = -1
while low <= high:
mid = (low + high) // 2
if tl.load(reorder_topk_ids + mid) > expert:
high = mid - 1
else:
low = mid + 1
target_location = mid
tl.store(seg_indptr + expert + 1, target_location + 1)
@triton.jit
def compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
tl.store(src2dst + src_id, dst_id, mask=mask)
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
BLOCK_SIZE = 512
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if a1_scales_ptr is not None:
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
else:
scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def silu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# silu & mul & quantize
gate_output = gate_output * tl.sigmoid(gate_output)
gate_output = gate_output.to(InDtype)
silu_mul_output = gate_output * up_output * scale
silu_mul_output = silu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
@triton.jit
def post_reorder_triton_kernel(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
computed = False
store_ptr = output_ptr + src_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx = tl.load(src2dst_ptr + idx)
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)
if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
@triton.jit
def compute_m_range(
pid,
batch_size,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
BLOCK_SIZE_M: tl.constexpr,
):
idx = 0
for bs in range(batch_size):
tiles = tl.load(m_num_tiles_indptr + bs)
if pid >= tiles:
idx = bs
idx_start = tl.load(m_num_tiles_indptr + idx)
m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
expert_id = tl.load(weight_indices + idx)
return m_range_start, m_range_end, expert_id
@triton.jit
def grouped_gemm_triton_kernel(
a,
b,
c,
batch_size,
N,
K,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
use_fp8_w8a8,
scale_a,
scale_b,
a_stride_0: tl.constexpr,
b_stride_0: tl.constexpr,
b_stride_1: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
c_dtype = c.dtype.element_ty
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
total_m_block = tl.load(m_num_tiles_indptr + batch_size)
if pid_m >= total_m_block:
return
m_range_start, m_range_end, expert_id = compute_m_range(
pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
)
if m_range_end - m_range_start == 0:
return
n_range_start = pid_n * BLOCK_SIZE_N
n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
b_ptr = b + (
(expert_id * b_stride_0)
+ (n_range_start + offs_bn[:, None]) * b_stride_1
+ offs_k[None, :]
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a_tile = tl.load(
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
b_tile = tl.load(
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
a_ptr += BLOCK_SIZE_K
b_ptr += BLOCK_SIZE_K
if use_fp8_w8a8:
scale_a_value = tl.load(scale_a + expert_id)
scale_b_value = tl.load(scale_b + expert_id)
accumulator *= scale_a_value * scale_b_value
c_tile = accumulator.to(c_dtype)
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
tl.store(c_ptr, c_tile, mask=c_mask)
@triton.jit
def compute_m_num_tiles_indptr(
m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
for bs in range(batch_size):
m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
def grouped_gemm_triton(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
batch_size: int,
weight_column_major: bool,
seg_indptr: Optional[torch.Tensor] = None,
weight_indices: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
):
assert weight_column_major == True # TODO: more
if use_fp8_w8a8:
assert scale_a is not None and scale_b is not None
config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
}
m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
compute_m_num_tiles_indptr[(1,)](
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
)
grid = lambda META: (
triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
)
grouped_gemm_triton_kernel[grid](
a,
b,
c,
batch_size,
b.size(1),
b.size(2),
seg_indptr,
weight_indices,
m_num_tiles_indptr,
use_fp8_w8a8,
scale_a,
scale_b,
a.stride(0),
b.stride(0),
b.stride(1),
**config,
)
return c
This diff is collapsed.
......@@ -58,6 +58,7 @@ global_server_args_dict = {
"torchao_config": ServerArgs.torchao_config,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
}
......
......@@ -141,6 +141,7 @@ class ModelRunner:
"torchao_config": server_args.torchao_config,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
}
)
......
......@@ -31,6 +31,7 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -113,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self.experts = FusedMoE(
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -834,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -38,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -63,6 +68,7 @@ class MixtralMoE(nn.Module):
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
......@@ -74,14 +80,13 @@ class MixtralMoE(nn.Module):
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
......@@ -95,6 +100,8 @@ class MixtralMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
......@@ -319,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
......
......@@ -93,6 +93,8 @@ class ServerArgs:
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Expert parallelism
ep_size: int = 1
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
......@@ -130,6 +132,7 @@ class ServerArgs:
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
......@@ -216,6 +219,12 @@ class ServerArgs:
"Data parallel size is adjusted to be the same as tensor parallel size. "
"Overlap scheduler is disabled."
)
# Expert parallelism
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.info(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
# GGUF
if (
......@@ -526,6 +535,14 @@ class ServerArgs:
"shortest_queue",
],
)
# Expert parallelism
parser.add_argument(
"--expert-parallel-size",
"--ep-size",
type=int,
default=ServerArgs.ep_size,
help="The expert parallelism size.",
)
# Multi-node distributed serving
parser.add_argument(
......@@ -681,6 +698,11 @@ class ServerArgs:
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
......@@ -760,6 +782,7 @@ class ServerArgs:
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestEpMoE(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--ep-size",
"2",
"--enable-ep-moe",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
class TestEpMoEFP8(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--ep-size",
"2",
"--enable-ep-moe",
"--quantization",
"fp8",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment