Unverified Commit 613b197e authored by fzyzcjy, committed by GitHub

Remove one kernel in per_tensor_quant_mla_fp8 (#5549)

parent d58e3544
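The kernel removed here is the torch.zeros((1,), ...) launch that per_tensor_quant_mla_fp8 previously issued on every call to create its per-tensor scale. After this change the caller passes a pre-zeroed float32 slot (x_s_out) taken from a BumpAllocator buffer that is zeroed once per model forward, so the per-call zeroing kernel goes away. A simplified before/after sketch of the call pattern (names follow the diff below; surrounding code elided):

    # Before: the scale is zeroed inside the function, one extra kernel launch per call
    #     x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
    #     x_q, x_s = per_tensor_quant_mla_fp8(x)
    # After: the caller hands in a pre-zeroed slot from a bump allocator
    #     x_s_out = zero_allocator.allocate(1)            # (1,) float32 view, already zero
    #     x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)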
@@ -58,10 +58,8 @@ if _is_cuda:
):
_enable_jit_deepgemm = True
logger = logging.getLogger(__name__)
if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt(
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(
def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12
x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with tensor-wise quantization
and is specialized for the MLA absorbed case.
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
assert (
x_s_out.shape == (1,)
and x_s_out.dtype == torch.float32
and x_s_out.device == x.device
)
x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
num_head, num_seq, head_size = x.shape
BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
_per_tensor_quant_mla_fp8_stage1[grid](
x,
x_s,
x_s_out,
head_size,
x.stride(0),
x.stride(1),
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
)
_per_tensor_quant_mla_fp8_stage2[grid](
x,
x_s,
x_s_out,
x_q,
num_seq,
head_size,
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
BLOCK_SIZE,
)
return x_q, x_s
return x_q, x_s_out
def scaled_fp8_quant(
......
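For reference, per_tensor_quant_mla_fp8 now requires the caller to supply a (1,)-shaped, pre-zeroed float32 buffer on the same device as x; the whole 3-D input shares a single scale and the function returns (x_q, x_s_out). A minimal usage sketch, assuming a CUDA device and that the function is imported from sglang.srt.layers.quantization.fp8_kernel (the import path is an assumption):

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8  # path assumed

    x = torch.randn(16, 4, 512, dtype=torch.bfloat16, device="cuda")    # (num_head, num_seq, head_size)
    x_s_out = torch.zeros((1,), dtype=torch.float32, device=x.device)   # pre-zeroed scale slot
    x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)                     # x_s is the same buffer, written in place
    assert x_q.shape == x.shape and x_s.shape == (1,)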
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=input_ids.device,
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
residual = None
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
......
@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
class AttnForwardMethod(IntEnum):
# Use multi-head attention
MHA = auto()
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
positions, hidden_states, forward_batch
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
def forward_normal(
self,
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
enable_rope_fusion = (
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
else:
raise NotImplementedError
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
# Gather
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
if self.attn_tp_size != 1:
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
dtype=torch.float32,
device=input_ids.device,
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is None:
......
@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
"MistralForCausalLM",
}
return architectures[0] in default_archs
# Can be more general if it is used in multiple places (keep it simple and thus not general now)
class BumpAllocator:
def __init__(self, buffer_size: int, dtype, device):
self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
self._pointer = 0
def allocate(self, size: int):
assert self._pointer + size <= len(self._buffer)
output = self._buffer[self._pointer : self._pointer + size]
self._pointer += size
return output
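The allocator is just a tensor zeroed once up front and handed out in slices: allocate() launches no kernel and returns a view that the quantization kernels later write into. A minimal sketch of the intended pattern (sizes are illustrative, mirroring the two allocations per decoder layer above):

    import torch

    num_layers = 4                                             # illustrative
    alloc = BumpAllocator(buffer_size=2 * num_layers, dtype=torch.float32, device="cuda")
    for _ in range(num_layers):
        q_scale = alloc.allocate(1)   # pre-zeroed (1,) view used as x_s_out for q_nope quantization
        o_scale = alloc.allocate(1)   # pre-zeroed (1,) view used as x_s_out for attn_output quantization

Because the slots are never re-zeroed, the allocator is recreated at the start of every forward pass, as DeepseekV2Model.forward does above.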