Unverified commit ec52464d, authored by Ke Bao, committed by GitHub

MLA prefill w/o weight absorption (#2349)

parent eb0c1f53
```diff
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
```
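The new `save_kv_cache` flag is the core plumbing of this change: every backend still computes attention on the freshly projected k/v, but the write into the token-to-KV pool becomes optional, so a layer that has already cached a different representation of the same tokens (here, MLA's compressed latent) can skip the redundant, differently-shaped write. A minimal, self-contained sketch of that contract, using toy classes rather than the sglang API:

```python
import torch


class ToyKVPool:
    """Stand-in for token_to_kv_pool: one K and one V slot per token location."""

    def __init__(self, num_slots: int, num_heads: int, head_dim: int):
        self.k_buffer = torch.zeros(num_slots, num_heads, head_dim)
        self.v_buffer = torch.zeros(num_slots, num_heads, head_dim)

    def set_kv_buffer(self, loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        self.k_buffer[loc] = k
        self.v_buffer[loc] = v


class ToyBackend:
    def forward_extend(self, q, k, v, pool: ToyKVPool, loc, save_kv_cache: bool = True):
        # The cache write is now gated: callers that cache their own representation
        # (e.g. an MLA latent) pass save_kv_cache=False and leave the pool untouched.
        if save_kv_cache:
            pool.set_kv_buffer(loc, k, v)
        # Attention itself always runs on the k/v passed in (scale and mask omitted).
        scores = torch.einsum("qhd,khd->hqk", q, k).softmax(dim=-1)
        return torch.einsum("hqk,khd->qhd", scores, v)


pool = ToyKVPool(num_slots=16, num_heads=2, head_dim=8)
q = k = v = torch.randn(4, 2, 8)
loc = torch.arange(4)
out_cached = ToyBackend().forward_extend(q, k, v, pool, loc)                      # writes the pool
out_skipped = ToyBackend().forward_extend(q, k, v, pool, loc, save_kv_cache=False)  # leaves it alone
```

The DeepseekV2AttentionMLA changes at the end of this diff use exactly this hook: the layer writes its latent cache itself and then calls the MHA attention with `save_kv_cache=False`.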
```diff
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
```
```diff
@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):

                 o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
```
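For context on the ragged branch above: when extend tokens are attended in two pieces (new tokens against themselves via the ragged wrapper, plus new tokens against the cached prefix via the paged wrapper), `merge_state` combines the two partial results using their softmax statistics, so attention never has to be recomputed over the union of both key sets. A reference sketch of that merge, assuming `s1`/`s2` are per-query, per-head log-sum-exp values of the raw attention scores (flashinfer's exact conventions may differ):

```python
import torch


def merge_attention_states(o1, s1, o2, s2):
    """o1, o2: [tokens, heads, head_dim] partial outputs over two disjoint key sets.
    s1, s2: [tokens, heads] log-sum-exp of the attention scores for each key set."""
    m = torch.maximum(s1, s2)                # stabilize before exponentiating
    w1 = torch.exp(s1 - m).unsqueeze(-1)
    w2 = torch.exp(s2 - m).unsqueeze(-1)
    o = (w1 * o1 + w2 * o2) / (w1 + w2)      # re-weight each output by its softmax mass
    s = m + torch.log(w1 + w2).squeeze(-1)   # merged log-sum-exp, if a caller needs it
    return o, s
```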
```diff
@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
```
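The `use_gqa` context line above is the native backend's grouped-query switch: when there are fewer KV heads than query heads, each KV head is shared by a group of query heads. A standalone sketch of that idea with `torch.nn.functional.scaled_dot_product_attention` (illustrative only; the sglang backend reads K/V back out of the pool rather than attending over the fresh tensors alone):

```python
import torch
import torch.nn.functional as F


def sdpa_extend(q, k, v, causal: bool = True):
    """q: [S, Hq, D]; k, v: [S, Hkv, D] with Hq an integer multiple of Hkv."""
    hq, hkv = q.shape[1], k.shape[1]
    if hq != hkv:
        # Grouped-query attention: replicate each KV head across its query group.
        k = k.repeat_interleave(hq // hkv, dim=1)
        v = v.repeat_interleave(hq // hkv, dim=1)
    # SDPA expects [..., heads, seq, dim]; move heads in front of the sequence axis.
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))
    o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return o.transpose(0, 1)  # back to [S, Hq, D]


out = sdpa_extend(torch.randn(16, 32, 128), torch.randn(16, 8, 128), torch.randn(16, 8, 128))
```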
```diff
@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
```
```diff
@@ -284,6 +284,9 @@ def extend_attention_fwd(
     elif Lq == 288:
         BLOCK_DMODEL = 256
         BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
```
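The new 192 branch exists for the MHA prefill path introduced in this commit: the non-absorbed query/key head is `qk_nope_head_dim + qk_rope_head_dim` wide, which is 128 + 64 = 192 for DeepSeek-V2 (values assumed here from the model config). Previously that size fell through to the `else` branch and was padded to `next_power_of_2(192) = 256`; splitting it into a 128-wide model block plus a 64-wide positional-encoding block avoids the padding, mirroring the existing 288 = 256 + 32 case. A quick check of the arithmetic (triton is imported only for `next_power_of_2`):

```python
import triton

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # DeepSeek-V2 head dims (assumed)
Lq = qk_nope_head_dim + qk_rope_head_dim      # 192, the MHA prefill qk head dim

assert triton.next_power_of_2(Lq) == 256      # old fallback: pad the whole head to 256
BLOCK_DMODEL, BLOCK_DPE = 128, 64             # new: a model block plus a rotary (PE) block
assert BLOCK_DMODEL + BLOCK_DPE == Lq         # no padding; the rope part is handled separately
```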
```diff
@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention

-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+        return forward_batch.attn_backend.forward(
+            q, k, v, self, forward_batch, save_kv_cache
+        )
```
```diff
@@ -453,7 +453,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -462,6 +462,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )

+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -471,6 +480,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -508,7 +574,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -835,7 +901,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj


 EntryClass = DeepseekV2ForCausalLM
```
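Why the two paths can share one cache: the pool entry written by `forward_normal` has the same latent layout the absorbed path reads at decode time (the compressed `kv_lora_rank` vector plus the rope key), so prefill can run full multi-head attention for speed while decode keeps the memory-friendly absorbed MQA form; it is also why `kv_b_proj` can no longer be deleted after weight loading. The two paths give the same attention scores because weight absorption is just reassociating a matrix product, q·(W_UK c) = (W_UKᵀ q)·c: folding W_UK into the query pays off when there is one query token and a long cache (decode), and loses when there are many query tokens (prefill). A quick numerical check of that identity (shapes chosen to match DeepSeek-V2, assumed here; float64 for a tight comparison):

```python
import torch

torch.manual_seed(0)
kv_lora_rank, qk_nope_head_dim = 512, 128                                # DeepSeek-V2 sizes (assumed)
W_uk = torch.randn(qk_nope_head_dim, kv_lora_rank, dtype=torch.float64)  # one head's key up-projection
c = torch.randn(kv_lora_rank, dtype=torch.float64)                       # cached latent for one token
q = torch.randn(qk_nope_head_dim, dtype=torch.float64)                   # nope part of one query head

score_normal = q @ (W_uk @ c)      # forward_normal: expand the latent to a full key, then dot with q
score_absorbed = (W_uk.T @ q) @ c  # forward_absorb: fold W_uk into the query, dot with the latent
assert torch.allclose(score_normal, score_absorbed)
```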