Commit 4c45697e authored by shangxl

fa3: add support for Qwen.

parent a55cb8b2
@@ -446,7 +446,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
@@ -458,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,
@@ -468,7 +468,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
@@ -487,9 +487,6 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
-            return self.forward_decode(q,k,v,layer,forward_batch, save_kv_cache)
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
@@ -517,7 +514,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
@@ -529,7 +526,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,
@@ -539,7 +536,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
...
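Note: the seq_lens change above folds speculative draft tokens into the KV lengths handed to the decode kernel, falling back to 0 when speculative decoding is off. A minimal sketch of that adjustment, assuming only what the diff shows (the helper name below is illustrative, not part of the patch):

    import torch
    from typing import Optional

    def effective_kv_lens(seq_lens: torch.Tensor, num_draft_tokens: Optional[int]) -> torch.Tensor:
        # Mirror the `... if ... is not None else 0` guard: no draft tokens means no offset.
        draft = num_draft_tokens if num_draft_tokens is not None else 0
        # The kernel expects int32 lengths that also cover the draft tokens being verified.
        return (seq_lens + draft).to(torch.int32)

    # Example: a batch of 3 sequences with 2 draft tokens each.
    print(effective_kv_lens(torch.tensor([5, 9, 12]), 2))  # tensor([ 7, 11, 14], dtype=torch.int32)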
@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
         self.is_hybrid = model_runner.is_hybrid
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         if self.is_hybrid:
             self.full_to_swa_index_mapping = (
                 model_runner.token_to_kv_pool.full_to_swa_index_mapping
@@ -672,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
+            # if not self.use_mla:
+            if k_rope is None:
                 if not self.use_mla:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v
+                    )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
@@ -694,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
             layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
         window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
@@ -778,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)
-            result = flash_attn_with_kvcache(
-                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
+            if forward_batch.attn_attend_prefix_cache:
+                assert not get_global_server_args().disable_chunked_prefix_cache
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None
+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0
+                assert forward_batch.mha_return_lse
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
+                    causal=False,
                     k_descale=k_descale,
                     v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
+                    return_softmax_lse=True,
                     **kwargs,
                 )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=self.forward_metadata_spec_decode_expand.page_table,
-                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+            else:
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
-                    causal=False,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
+                    causal=True,
                     k_descale=k_descale,
                     v_descale=v_descale,
-                    return_softmax_lse=True,
-                    num_splits=self.num_splits,
+                    return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )
-                o, _ = merge_state_v2_wrapper(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            if forward_batch.mha_return_lse:
+                output, lse, *rest = output
+                lse = torch.transpose(lse, 0, 1).contiguous()
+                return output, lse
+            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if (
                 forward_batch.attn_attend_prefix_cache is not None
@@ -855,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
                     causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=True,
                     **kwargs,
                 )
@@ -869,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
                     causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )
@@ -978,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
+            # if not self.use_mla:
+            if k_rope is None:
                 if not self.use_mla:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v
+                    )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
@@ -1023,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
         if sinks is not None:
             kwargs["sinks"] = sinks
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
@@ -1037,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )
@@ -1089,26 +1106,33 @@ class FlashAttentionBackend(AttentionBackend):
                     **kwargs,
                 )
             else:
+                cu_seqlens_q = metadata.cu_seqlens_q
+                max_seqlen_q = metadata.max_seq_len_q
                 page_table = metadata.page_table
-                cache_seqlens = metadata.cache_seqlens_int32
                 cu_seqlens_k = metadata.cu_seqlens_k
-                max_seqlen_q = metadata.max_seq_len_q
-                q_reshaped = q.contiguous().view(
-                    -1, layer.tp_q_head_num, layer.head_dim
+                cache_seqlens = metadata.cache_seqlens_int32
+                key_cache = key_cache.view(
+                    -1, self.page_size, layer.tp_k_head_num, layer.head_dim
                 )
-
-                # Default: single-token self-attention
-                result = flash_attn_with_kvcache(
-                    q=q_reshaped,
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=page_table,
-                    cache_seqlens=cache_seqlens,
-                    cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k_new=cu_seqlens_k,
+                value_cache = value_cache.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+                )
+                if layer.is_cross_attention:
+                    page_table = metadata.encoder_page_table
+                    cache_seqlens = metadata.encoder_lens_int32
+                    cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                    window_size = (-1, -1)
+                if max_seqlen_q > 1:
+                    result = flash_attn_varlen_func(
+                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k=cu_seqlens_k,
                         max_seqlen_q=max_seqlen_q,
+                        max_seqlen_k=max_seqlen_q,
                         softmax_scale=layer.scaling,
-                    causal=False if use_cascade_attn else causal,
+                        causal=True,
                         window_size=window_size,
                         softcap=layer.logit_cap,
                         k_descale=k_descale,
@@ -1117,36 +1141,26 @@ class FlashAttentionBackend(AttentionBackend):
                         num_splits=self.num_splits,
                         **kwargs,
                     )
-                if use_cascade_attn:
-                    o, softmax_lse, *rest = result
-                    o_expand, softmax_lse_expand, *rest_expand = (
-                        flash_attn_with_kvcache(
-                            q=q_reshaped,
-                            k_cache=key_cache,
-                            v_cache=value_cache,
-                            page_table=self.forward_metadata_spec_decode_expand.page_table,
-                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                            softmax_scale=layer.scaling,
-                            causal=False,
-                            window_size=window_size,
-                            softcap=layer.logit_cap,
-                            k_descale=k_descale,
-                            v_descale=v_descale,
-                            return_softmax_lse=True,
-                            num_splits=self.num_splits,
-                            **kwargs,
-                        )
-                    )
-                    o, _ = merge_state_v2(
-                        o,
-                        softmax_lse.T.contiguous(),
-                        o_expand,
-                        softmax_lse_expand.T.contiguous(),
-                    )
-                else:
+                else:
+                    result = flash_attn_with_kvcache(
+                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k_cache=key_cache,
+                        v_cache=value_cache,
+                        page_table=page_table,
+                        cache_seqlens=cache_seqlens,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                        max_seqlen_q=max_seqlen_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        window_size=window_size,
+                        softcap=layer.logit_cap,
+                        k_descale=k_descale,
+                        v_descale=v_descale,
+                        return_softmax_lse=use_cascade_attn,
+                        num_splits=self.num_splits,
+                        **kwargs,
+                    )
                 o = result
         else:
             # Do absorbed multi-latent attention
...
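Note: the k_descale/v_descale change above swaps the previous None for tensors pinned at 1.0, so the fa3 kernels always receive a descale factor; a descale of 1.0 is the identity for fp16/bf16 KV caches. A minimal sketch of that arithmetic (the helper below is illustrative, not the backend's API):

    import torch

    def apply_kv_descale(kv: torch.Tensor, descale: torch.Tensor) -> torch.Tensor:
        # fa3-style descaling multiplies stored KV values by a per-tensor factor
        # before attention; a factor of 1.0 leaves unquantized caches untouched.
        return kv * descale

    k = torch.randn(4, 8)
    assert torch.equal(apply_kv_descale(k, torch.tensor([1.0])), k)  # 1.0 descale is a no-op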
@@ -42,8 +42,8 @@ def flash_attn_with_kvcache(
 ):
     return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
-        k_cache=k_cache,
-        v_cache=v_cache,
+        k_cache=k_cache.view(q.dtype),
+        v_cache=v_cache.view(q.dtype),
         block_table=page_table,
         cache_seqlens=cache_seqlens,
         softmax_scale=softmax_scale,
@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
 ):
     return flash_attn_varlen_func_interface(
         q=q,
-        k=k,
-        v=v,
+        k=k.view(q.dtype),
+        v=v.view(q.dtype),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
         softmax_scale=softmax_scale,
         causal=causal,
         return_attn_probs=return_softmax_lse,
+        softcap=softcap,
     )
\ No newline at end of file
@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)
         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
...
@@ -356,7 +356,8 @@ def alloc_for_extend(
     else:
         # Paged allocation - build last_loc
         last_loc = [
-            (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            # (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            (t[-1:] if len(t) > 0 else batch.loc_tensor)
             for t in prefix_tensors
         ]
         out_cache_loc = alloc_paged_token_slots_extend(
...
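Note: the loc_tensor change above hoists the [-1] sentinel out of the per-request list comprehension, so it is allocated once on the batch instead of once for every empty prefix. A minimal sketch of the pattern (the class and helper names below are illustrative, not the scheduler's API):

    import torch

    class Batch:
        def __init__(self, device: str = "cpu"):
            self.device = device
            # Allocate the sentinel once; reused for every request with an empty prefix.
            self.loc_tensor = torch.tensor([-1], device=self.device)

    def build_last_loc(batch: Batch, prefix_tensors):
        # Reuse the cached sentinel instead of calling torch.tensor(...) inside the loop.
        return [t[-1:] if len(t) > 0 else batch.loc_tensor for t in prefix_tensors]

    batch = Batch()
    prefixes = [torch.tensor([3, 4, 5]), torch.tensor([], dtype=torch.int64)]
    print(build_last_loc(batch, prefixes))  # [tensor([5]), tensor([-1])]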