Unverified commit a38376fa, authored by fzyzcjy and committed by GitHub

Refactor attention into multiple stages (#6477)

parent 7a5e6ce1
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()

+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
             ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+            return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

         if attn_forward_method == AttnForwardMethod.MHA:
-            return self.forward_normal(positions, hidden_states, forward_batch)
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-            return self.forward_normal_chunked_kv(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            return self.forward_absorb(
+            inner_state = self.forward_absorb_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-            return self.forward_absorb_fused_mla_rope(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         else:
             raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError

-    def forward_normal(
+    def forward_normal_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
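A note on the shape of the refactor above: forward_prepare does the dispatch and all pre-attention work and packs the result into a plain tuple (hidden_states, attn_forward_method, forward_batch, inner_state); forward_core unpacks it and finishes the computation, so forward() is just the composition of the two. The snippet below is a minimal, self-contained sketch of that contract; the class name, the enum, and the toy arithmetic are illustrative stand-ins, not sglang code, and forward_batch is omitted for brevity.

# Illustrative sketch of the prepare/core split (toy names and toy math,
# not sglang APIs): forward() == forward_core(forward_prepare(...)).
import enum

import torch


class _Method(enum.Enum):
    MHA = enum.auto()
    MLA = enum.auto()


class _TwoStageAttention:
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # One-shot path: compose the two stages, as DeepseekV2AttentionMLA.forward does.
        return self.forward_core(self.forward_prepare(hidden_states))

    def forward_prepare(self, hidden_states: torch.Tensor):
        if hidden_states.shape[0] == 0:
            # Empty-batch short circuit: nothing to compute, pass the input through.
            return hidden_states, None, None
        method = _Method.MHA if hidden_states.shape[-1] <= 128 else _Method.MLA
        inner_state = (hidden_states * 2.0,)  # stand-in for q/k/v, projections, etc.
        return None, method, inner_state

    def forward_core(self, intermediate_state) -> torch.Tensor:
        hidden_states, method, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states
        if method is _Method.MHA:
            return inner_state[0] + 1.0
        return inner_state[0] - 1.0


_attn = _TwoStageAttention()
_x = torch.ones(2, 64)
assert torch.equal(_attn.forward(_x), _attn.forward_core(_attn.forward_prepare(_x)))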
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output

-    def forward_absorb(
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out.transpose(0, 1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
         if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output

-    def forward_absorb_fused_mla_rope(
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return accum_output

-    def forward_normal_chunked_kv(
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
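For context on the chunked-KV path: each chunk's attention call returns both the output and its log-sum-exp (lse) so that per-chunk partial results can be combined into the exact full-attention result. The helper below is a hedged, reference-style sketch of that merge; the function name and tensor shapes are assumptions for illustration, and the actual code path may rely on a fused merge kernel instead.

# Reference (non-fused) log-sum-exp merge of two attention partials computed
# over disjoint key chunks for the same queries. Assumed shapes:
#   o1, o2:     [num_tokens, num_heads, head_dim]
#   lse1, lse2: [num_tokens, num_heads]
import torch


def merge_attention_partials(o1, lse1, o2, lse2):
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)  # un-normalized weight of chunk 1
    w2 = torch.exp(lse2 - max_lse)  # un-normalized weight of chunk 2
    denom = w1 + w2
    merged = o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
    merged_lse = max_lse + torch.log(denom)
    return merged, merged_lse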
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
         )

-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
......
@@ -7,7 +7,8 @@ def compute_layer_operations(
     if not layer.is_layer_sparse:
         return [
             layer.op_comm_prepare_attn,
-            layer.op_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
             layer.op_comm_prepare_mlp,
             layer.op_mlp,
             layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
     # Will add TBO operation orders here
     return [
         layer.op_comm_prepare_attn,
-        layer.op_attn,
+        layer.self_attn.op_prepare,
+        layer.self_attn.op_core,
         layer.op_comm_prepare_mlp,
         layer.mlp.op_gate,
         layer.mlp.op_shared_experts,
......
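Finally, on the operations-list change: with op_attn split into self_attn.op_prepare and self_attn.op_core, each decoder layer is driven as a flat list of state-mutating operations, which is what makes finer-grained scheduling possible (the "Will add TBO operation orders here" comment), e.g. interleaving the prepare and core stages of two micro-batches. The sketch below shows the kind of state object and driver loop the op_*(state) signatures imply; LayerState and run_layer are illustrative names, not sglang APIs.

# Hedged sketch of a driver for an op list such as the one returned by
# compute_layer_operations. Each op takes one mutable state object and reads or
# writes named fields on it (e.g. op_prepare pops
# "hidden_states_after_comm_pre_attn" and stores "attn_intermediate_state").
from types import SimpleNamespace


class LayerState(SimpleNamespace):
    def pop(self, name):
        value = getattr(self, name)
        delattr(self, name)  # release the reference once it has been consumed
        return value


def run_layer(ops, state):
    for op in ops:
        op(state)  # every op has the signature op(state)
    return state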