Unverified commit f194e14f, authored by fzyzcjy, committed by GitHub

Reduce MoE memory usage (#6147)
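In short, the patch lowers peak activation memory on the expert-parallel MoE path in two ways: large intermediates (the incoming `hidden_states`, `gateup_input`, `gateup_output`, `down_input`, ...) are released with `del` or the new `dispose_tensor` helper as soon as the next stage has consumed them, and `grouped_gemm_triton` gains a `c_dtype` argument so the output buffer can be allocated inside the GEMM wrapper instead of being pre-allocated by the caller. The sketch below illustrates only the eager-release pattern; the function and weight names (`moe_like_pipeline`, `w_up`, `w_down`) are made up for illustration and are not code from this PR.

```python
import torch
import torch.nn.functional as F


def dispose_tensor(x: torch.Tensor) -> None:
    # Re-point the tensor at an empty storage so its original buffer becomes
    # reclaimable even while other references to the same tensor object exist.
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))


def moe_like_pipeline(hidden_states, w_up, w_down):
    # Stage 1: up-projection; the input activation is not needed afterwards,
    # so release its buffer now rather than at function exit.
    up = hidden_states @ w_up
    dispose_tensor(hidden_states)

    # Stage 2: activation; drop the up-projection output once it is consumed.
    act = F.silu(up)
    del up

    # Stage 3: down-projection; drop the activation buffer before returning.
    out = act @ w_down
    del act
    return out
```

With this ordering, only a couple of large intermediates are alive at any point, instead of every stage's output surviving until the function returns.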

parent cfc9f9ab
@@ -3,10 +3,9 @@ from typing import List, Optional
 import torch
 import triton
-import triton.language as tl
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda

 logger = logging.getLogger(__name__)
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
     if block_shape is not None:
+        a_original = a
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
+        dispose_tensor(a_original)

     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )

+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
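A side note on the `c_dtype` path above: when `c is None`, the output is allocated only after `a` has been quantized and the original activation disposed, so the caller never holds the bf16 input, the fp8 copy, and the output buffer at the same time. Below is a rough, self-contained sketch of that deferred-allocation idea; the name `gemm_with_deferred_output` and the plain `torch.matmul` stand-in are illustrative, not the actual Triton kernel.

```python
import torch
from typing import Optional


def gemm_with_deferred_output(
    a: torch.Tensor,                   # (M, K) activations
    b: torch.Tensor,                   # (K, N) weights
    c: Optional[torch.Tensor] = None,  # optional pre-allocated output
    c_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Allocate the output lazily, inside the GEMM wrapper, so the caller does
    # not have to keep an output buffer alive through any preprocessing of `a`.
    if c is None:
        assert c_dtype is not None
        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
    torch.matmul(a, b, out=c)
    return c


x = torch.randn(8, 16)
w = torch.randn(16, 32)
y = gemm_with_deferred_output(x, w, c_dtype=torch.float32)  # caller passes no buffer
```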
...
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs

 _is_hip = is_hip()
@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         assert self.quant_method is not None

         if self.grouped_gemm_runner is None:
@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)

         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )

         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input

         # Act
         down_input = torch.empty(
@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )

         if self.activation == "silu":
@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output

         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
), ),
block_shape=self.block_shape, block_shape=self.block_shape,
) )
del down_input
# PostReorder # PostReorder
output = torch.empty_like(hidden_states) output = torch.empty(
post_reorder_triton_kernel[(hidden_states.size(0),)]( hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output, down_output,
output, output,
src2dst, src2dst,
@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         assert self.quant_method is not None
         assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
         )

         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )

         # Act
         down_input = torch.empty(
@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )

         if self.activation == "silu":
@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output

         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype

         input_tensor = [
             torch.empty(
@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

         return gather_out
@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])

         # Act
         down_input = torch.empty(
@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output

         # GroupGemm-1
         n = self.w2_weight.size(1)
...
@@ -311,10 +311,10 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
...
@@ -2100,3 +2100,7 @@ def log_info_on_rank0(logger, msg):
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
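A note on why `dispose_tensor` exists alongside plain `del`: `del` only removes one Python name, so a buffer passed in from a caller stays alive as long as the caller's reference does. `Tensor.set_` instead swaps the storage of the shared tensor object itself, so the large buffer becomes reclaimable immediately; this is also why `EPMoE.forward` snapshots `hidden_states.shape`, `.dtype`, and `.device` before disposing it. A minimal illustration (not part of the diff):

```python
import torch


def dispose_tensor(x: torch.Tensor) -> None:
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))


def consumer(t: torch.Tensor) -> None:
    # `del t` here would only drop the local name; the caller's reference
    # would keep the buffer alive. dispose_tensor() rebinds the storage of
    # the shared tensor object, so the buffer can be reclaimed right away.
    dispose_tensor(t)


a = torch.empty(1024, 1024)  # ~4 MiB fp32 buffer
consumer(a)
assert a.numel() == 0        # the caller's tensor now points at empty storage
```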