Unverified commit a38376fa authored by fzyzcjy, committed by GitHub

Refactor attention into multiple stages (#6477)

parent 7a5e6ce1
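At a high level, this commit splits `DeepseekV2AttentionMLA.forward` into a `forward_prepare` stage (dispatching the attention method and building a per-method intermediate state) and a `forward_core` stage (running the attention kernel and output projection), with `forward` kept as a thin wrapper that chains the two. The sketch below is a simplified, hypothetical illustration of that pattern only: the class name and helper bodies are stand-ins, and the real methods also take `positions`, a `ForwardBatch`, and a `BumpAllocator`.

```python
# Simplified illustration of the prepare/core split; class and helper names
# are placeholders, and the dispatch logic is heavily abbreviated.
from enum import Enum, auto


class AttnForwardMethod(Enum):
    MHA = auto()
    MLA = auto()


class TwoStageAttentionSketch:
    def dispatch_attn_forward_method(self, forward_batch) -> AttnForwardMethod:
        # Placeholder for the real dispatch logic.
        return AttnForwardMethod.MLA

    def forward_prepare(self, hidden_states, forward_batch):
        # Empty batch: short-circuit; inner_state=None tells the core stage
        # to return hidden_states unchanged (the real code checks shape[0]).
        if len(hidden_states) == 0:
            return hidden_states, None, forward_batch, None
        method = self.dispatch_attn_forward_method(forward_batch)
        if method is AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(hidden_states, forward_batch)
        else:
            inner_state = self.forward_absorb_prepare(hidden_states, forward_batch)
        return None, method, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        hidden_states, method, forward_batch, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states
        if method is AttnForwardMethod.MHA:
            return self.forward_normal_core(*inner_state)
        return self.forward_absorb_core(*inner_state)

    def forward(self, hidden_states, forward_batch):
        # The original single-call API is preserved by chaining both stages.
        return self.forward_core(self.forward_prepare(hidden_states, forward_batch))

    # --- trivial stand-ins so the sketch runs end to end -------------------
    def forward_normal_prepare(self, hidden_states, forward_batch):
        return hidden_states, forward_batch

    def forward_normal_core(self, hidden_states, forward_batch):
        return hidden_states

    def forward_absorb_prepare(self, hidden_states, forward_batch):
        return hidden_states, forward_batch

    def forward_absorb_core(self, hidden_states, forward_batch):
        return hidden_states
```

The split lets the layer-level operation list call `op_prepare` and `op_core` separately (see the last hunk of the diff), so the scheduler can place other work between the two attention stages.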
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
else:
return _dispatch_mla_subtype()
def op_prepare(self, state):
state.attn_intermediate_state = self.forward_prepare(
positions=state.positions,
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
)
def op_core(self, state):
state.hidden_states_after_attn = self.forward_core(
state.pop("attn_intermediate_state")
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
):
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
return self.forward_core(s)
def forward_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
):
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
return hidden_states, None, forward_batch, None
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
if attn_forward_method == AttnForwardMethod.MHA:
return self.forward_normal(positions, hidden_states, forward_batch)
inner_state = self.forward_normal_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
return self.forward_normal_chunked_kv(
positions, hidden_states, forward_batch
inner_state = self.forward_normal_chunked_kv_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb(
inner_state = self.forward_absorb_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope(
positions, hidden_states, forward_batch
inner_state = self.forward_absorb_fused_mla_rope_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
else:
raise NotImplementedError
return None, attn_forward_method, forward_batch, inner_state
def forward_normal(
def forward_core(self, intermediate_state):
hidden_states, attn_forward_method, forward_batch, inner_state = (
intermediate_state
)
if inner_state is None:
return hidden_states
if attn_forward_method == AttnForwardMethod.MHA:
return self.forward_normal_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
return self.forward_normal_chunked_kv_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope_core(*inner_state)
else:
raise NotImplementedError
def forward_normal_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
zero_allocator: BumpAllocator,
):
if self.q_lora_rank is not None:
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
return q, k, v, forward_batch
def forward_normal_core(self, q, k, v, forward_batch):
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
def forward_absorb(
def forward_absorb_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
def forward_absorb_core(
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
):
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def forward_absorb_fused_mla_rope(
def forward_absorb_fused_mla_rope_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
):
enable_rope_fusion = (
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
)
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
)
val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
return (
q_input,
key_cache_buf,
val_cache_buf,
attn_output,
kv_indptr,
kv_indices,
k_pe_output,
cos_sin_cache,
positions,
attn_logits,
num_kv_split,
sm_scale,
enable_rope_fusion,
k_input,
forward_batch,
zero_allocator,
)
def forward_absorb_fused_mla_rope_core(
self,
q_input,
key_cache_buf,
val_cache_buf,
attn_output,
kv_indptr,
kv_indices,
k_pe_output,
cos_sin_cache,
positions,
attn_logits,
num_kv_split,
sm_scale,
enable_rope_fusion,
k_input,
forward_batch,
zero_allocator,
):
decode_attention_fwd_grouped_rope(
q_input,
key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
return accum_output
def forward_normal_chunked_kv(
def forward_normal_chunked_kv_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
zero_allocator: BumpAllocator,
):
# In normal mha, the k and v tensors will become overly large when the prefix length is long.
# To avoid this, we split the kv cache into chunks and process them one after another.
# Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
return q, k, v, forward_batch
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
# Do mha for extended part without prefix
forward_batch.set_attn_attend_prefix_cache(False)
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
)
)
def op_attn(self, state):
state.hidden_states_after_attn = self.self_attn(
positions=state.positions,
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
)
def op_comm_prepare_mlp(self, state):
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
self.layer_communicator.prepare_mlp(
......
@@ -7,7 +7,8 @@ def compute_layer_operations(
if not layer.is_layer_sparse:
return [
layer.op_comm_prepare_attn,
layer.op_attn,
layer.self_attn.op_prepare,
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.op_mlp,
layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
# Will add TBO operation orders here
return [
layer.op_comm_prepare_attn,
layer.op_attn,
layer.self_attn.op_prepare,
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_shared_experts,
......
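For context, the per-layer operation lists above are flat sequences of callables executed against a shared state object, so replacing `op_attn` with `self_attn.op_prepare` and `self_attn.op_core` allows a scheduler to run other operations (for example, the TBO orderings hinted at by the placeholder comment) between the two attention stages. The runner and ops below are invented stand-ins, not SGLang's actual implementation, shown only to illustrate how such a list might be consumed:

```python
# Hedged sketch, not the real SGLang code: run_layer, the dict-backed state,
# and the fake prepare/core ops below are invented for illustration.
from typing import Callable, Dict, List

Op = Callable[[Dict], None]


def run_layer(ops: List[Op], state: Dict) -> Dict:
    """Execute a flat per-layer operation list against a shared state dict."""
    # Because the list is flat, a scheduler is free to interleave ops from
    # different layers or micro-batches between op_prepare and op_core.
    for op in ops:
        op(state)
    return state


# Two fake attention ops mirroring the op_prepare / op_core split: prepare
# stashes an intermediate result in the state, core pops and consumes it.
def fake_op_prepare(state: Dict) -> None:
    state["attn_intermediate_state"] = state.pop("hidden_states") + "-prepared"


def fake_op_core(state: Dict) -> None:
    state["hidden_states_after_attn"] = (
        state.pop("attn_intermediate_state") + "-attended"
    )


if __name__ == "__main__":
    out = run_layer([fake_op_prepare, fake_op_core], {"hidden_states": "h"})
    print(out["hidden_states_after_attn"])  # h-prepared-attended
```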