Unverified Commit 613b197e authored by fzyzcjy, committed by GitHub

Remove one kernel in per_tensor_quant_mla_fp8 (#5549)

parent d58e3544
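In short: `per_tensor_quant_mla_fp8` used to allocate its one-element scale tensor with `torch.zeros` on every call, which costs an extra kernel launch for each fp8 bmm in the absorbed-MLA path. After this change the caller passes in a preallocated, zero-filled float32 slice (`x_s_out`), handed out by a small `BumpAllocator` that each model forward creates once and threads through the decoder layers. The fragment below only illustrates the call-site change, assembled from the hunks that follow; it is not standalone code.

# Before (inside forward_absorb): the helper zeroed its own scale tensor internally.
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
    q_nope.transpose(0, 1),
)

# After: the zero-filled scale slot comes from the per-forward allocator.
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
    q_nope.transpose(0, 1),
    zero_allocator.allocate(1),
)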
@@ -58,10 +58,8 @@ if _is_cuda:
     ):
         _enable_jit_deepgemm = True

 logger = logging.getLogger(__name__)

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(


 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )

     x_q = x.new_empty(x.size(), dtype=_fp8_type)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
         BLOCK_SIZE,
     )
-    return x_q, x_s
+    return x_q, x_s_out


 def scaled_fp8_quant(
...
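The hunks above change the quantization helper's contract; the hunks below thread a `BumpAllocator` through the NextN draft model and the DeepSeek-V2/V3 model so every call site can supply the zero-filled scale slot. As a reading aid, here is a plain-PyTorch sketch of what `per_tensor_quant_mla_fp8` computes under its new signature. It is an assumption-laden reference, not the Triton implementation: the real function launches two Triton kernels (stage 1 reduces the tensor-wide scale into `x_s_out`, stage 2 applies it), and the exact eps handling, fp8 dtype (`_fp8_type` differs on HIP), and kernel split are not reproduced here.

import torch

FP8_MAX = 448.0  # assumed float8_e4m3fn max on CUDA


def per_tensor_quant_mla_fp8_reference(
    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
):
    """Illustrative tensor-wise fp8 quantization matching the new call contract."""
    assert x.dim() == 3, "`x` is not a 3d-tensor"
    assert x_s_out.shape == (1,) and x_s_out.dtype == torch.float32
    # "Stage 1": one tensor-wide scale, written into the caller-provided buffer.
    x_s_out.copy_(x.abs().amax().clamp(min=eps).float() / FP8_MAX)
    # "Stage 2": scale, clamp to the fp8 range, and cast.
    x_q = (x / x_s_out).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, x_s_out


# Example (CPU, illustration only): in the model, zero_allocator.allocate(1)
# plays the role of the fresh torch.zeros((1,)) buffer used here.
x = torch.randn(16, 8, 128)  # (num_head, num_seq, head_size)
x_q, x_s = per_tensor_quant_mla_fp8_reference(x, torch.zeros(1, dtype=torch.float32))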
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )

         if not forward_batch.forward_mode.is_idle():
...
@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)

 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )

     def forward_normal(
         self,
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:

         if hidden_states.shape[0] == 0:
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                zero_allocator=zero_allocator,
             )

         # Gather
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:

         if hidden_states.shape[0] == 0:
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                zero_allocator=zero_allocator,
             )

         if self.attn_tp_size != 1:
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )

         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
...
@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
         "MistralForCausalLM",
     }
     return architectures[0] in default_archs
+
+
+# Can be more general if it is used in multiple places (keep it simple and thus not general now)
+class BumpAllocator:
+    def __init__(self, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+        self._pointer = 0
+
+    def allocate(self, size: int):
+        assert self._pointer + size <= len(self._buffer)
+        output = self._buffer[self._pointer : self._pointer + size]
+        self._pointer += size
+        return output
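`BumpAllocator` is new in this patch, so a minimal usage sketch may help. The CPU device and the commented import path are assumptions for illustration; in the models it is constructed once per forward pass with a CUDA float32 buffer, and `allocate(1)` hands out the zero-filled scale slots consumed by `per_tensor_quant_mla_fp8`.

import torch

# from sglang.srt.utils import BumpAllocator  # assumed import path

allocator = BumpAllocator(buffer_size=4, dtype=torch.float32, device="cpu")

a = allocator.allocate(1)  # view of buffer[0:1], already zero-filled
b = allocator.allocate(2)  # view of buffer[1:3]
assert a.shape == (1,) and float(a) == 0.0
assert b.shape == (2,)

# Allocations are views into one shared buffer, so writes land there too.
a.fill_(3.0)
assert float(allocator._buffer[0]) == 3.0

# allocator.allocate(2)  # would trip the assert: only one element remains

Note that the buffer is zeroed exactly once, in `__init__`, and slots are never recycled; presumably the stage-1 quantization kernel reduces the scale into its slot rather than overwriting it, so each call needs a slot that still holds zero. That is consistent with both models building a fresh allocator at the top of `forward()` (`buffer_size=2` for the single-layer NextN module, `len(self.layers) * 2` for the full model: one slot for `q_nope` and one for `attn_output` per layer).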