"torchvision/models/vscode:/vscode.git/clone" did not exist on "4d711fdc21d1ccbc4d366f7bfddb318ebb7408a0"
Commit 4c45697e authored by shangxl

fa3: support qwen.

parent a55cb8b2
......@@ -446,7 +446,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
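The hunk above narrows the accepted fp8 KV-cache dtypes from all four torch float8 formats to the two e4m3 variants (plain `fn` and the ROCm/DCU-style `fnuz`, which has no inf encodings and a single NaN). Presumably the fa3/MLA decode kernel on this hardware only has an e4m3 dequantization path; that motivation is my assumption, not stated in the commit. A minimal sketch of the check:

```python
import torch

# e4m3 variants only; the e5m2 formats are no longer accepted on this path.
FP8_E4M3_DTYPES = (torch.float8_e4m3fn, torch.float8_e4m3fnuz)

def is_supported_fp8(dtype: torch.dtype) -> bool:
    return dtype in FP8_E4M3_DTYPES

assert is_supported_fp8(torch.float8_e4m3fnuz)
assert not is_supported_fp8(torch.float8_e5m2)
```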
......@@ -458,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
......@@ -468,7 +468,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
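For context on the `seq_lens + num_draft_tokens` change in this hunk: with target-verify speculative decoding the KV cache already holds the draft tokens being verified, so the decode kernel needs to attend over `seq_len + num_draft_tokens` keys, which is the usual reason for such an offset; when speculation is disabled, `num_draft_tokens` is normalized to 0 a few lines earlier. A minimal illustration (values made up):

```python
import torch

seq_lens = torch.tensor([10, 37, 5])   # committed tokens per request
num_draft_tokens = 4                   # draft tokens currently being verified
kv_lens = (seq_lens + num_draft_tokens).to(torch.int32)
print(kv_lens)                         # tensor([14, 41,  9], dtype=torch.int32)
```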
......@@ -487,9 +487,6 @@ class DCUMLABackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if save_kv_cache:
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
if (
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
......@@ -517,7 +514,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
......@@ -529,7 +526,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
......@@ -539,7 +536,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
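In the fp8 branch of this pair of calls, the kernel additionally receives `k_scale` and `kv_cache_dtype` so it can dequantize the stored cache; the non-fp8 branch just above omits them. A generic sketch of per-tensor fp8 KV quantization, independent of the DCU kernel (448 is the largest normal e4m3 value):

```python
import torch

k = torch.randn(16, 128)                        # a K block (bf16/fp16 in practice)
k_scale = k.abs().amax() / 448.0                # per-tensor scale for e4m3
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)   # what the pool would store
k_dequant = k_fp8.to(torch.float32) * k_scale   # what the kernel reconstructs
print((k - k_dequant).abs().max())              # small quantization error
```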
......
......@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
......@@ -672,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -694,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
......@@ -778,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
causal=False,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
return_softmax_lse=True,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
else:
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
if (
forward_batch.attn_attend_prefix_cache is not None
......@@ -855,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
**kwargs,
)
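The chunked-prefix calls in this file return the softmax log-sum-exp (`return_softmax_lse=True`) because attention over a long prefix is computed chunk by chunk and the partial outputs are combined afterwards (the diff's `merge_state_v2_wrapper`). The merge only needs each chunk's output and LSE. A self-contained sketch of that identity in plain torch (not the sglang/fa3 kernels):

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 8)                        # one query, head_dim 8
k = torch.randn(16, 8)
v = torch.randn(16, 8)

def partial_attn(q, k, v):
    scores = q @ k.T                          # (1, n_keys)
    lse = torch.logsumexp(scores, dim=-1)     # log of the softmax denominator
    out = torch.softmax(scores, dim=-1) @ v   # (1, head_dim)
    return out, lse

o1, l1 = partial_attn(q, k[:10], v[:10])      # prefix chunk
o2, l2 = partial_attn(q, k[10:], v[10:])      # current chunk
l = torch.logaddexp(l1, l2)
merged = torch.exp(l1 - l)[:, None] * o1 + torch.exp(l2 - l)[:, None] * o2

full, _ = partial_attn(q, k, v)               # attention over all keys at once
assert torch.allclose(merged, full, atol=1e-5)
```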
......@@ -869,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
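The varlen call just above packs the whole batch into one ragged dimension: `cu_seqlens_q`/`cu_seqlens_k` are exclusive prefix sums of the per-request lengths and `max_seqlen_*` is the longest one; since this branch attends only among the newly extended tokens, it reuses the query prefix sums for the keys. A generic sketch of how such prefix sums are built, following the usual flash-attn varlen convention rather than a specific sglang helper (lengths made up):

```python
import torch
import torch.nn.functional as F

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seq_lens.max())

print(cu_seqlens)   # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(max_seqlen)   # 7
```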
......@@ -978,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -1023,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
if sinks is not None:
kwargs["sinks"] = sinks
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
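In both the extend and decode paths the diff replaces the `k_descale, v_descale = None, None` default with the unit-scale tensors registered in `__init__` (`self.k_scale`/`self.v_scale`). My reading is that the fa3 kernel used here expects real tensors for the descale arguments, and a scale of 1.0 is a no-op for non-fp8 runs; the commit does not state the motivation. A minimal sketch of why the unit scale is safe:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
k_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
v_scale = torch.tensor([1.0], dtype=torch.float32, device=device)

# Dequantizing with a unit scale is the identity, so fp16/bf16 paths are
# numerically unchanged while the kernel still receives a valid tensor.
x = torch.randn(4, 8, device=device)
assert torch.equal(x * k_scale, x)
```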
......@@ -1037,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
......@@ -1089,26 +1106,33 @@ class FlashAttentionBackend(AttentionBackend):
**kwargs,
)
else:
cu_seqlens_q = metadata.cu_seqlens_q
max_seqlen_q = metadata.max_seq_len_q
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
cache_seqlens = metadata.cache_seqlens_int32
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
if max_seqlen_q > 1:
result = flash_attn_varlen_func(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
......@@ -1117,36 +1141,26 @@ class FlashAttentionBackend(AttentionBackend):
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
else:
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
# Do absorbed multi-latent attention
......
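As far as the visible hunks show, the rewritten non-MLA decode path views the flat token-slot KV pool as pages and then dispatches on `max_seqlen_q`: more than one query token per request (e.g. verifying several draft tokens) goes through `flash_attn_varlen_func` over the fresh K/V, while the single-token case goes through `flash_attn_with_kvcache` against the paged cache. A generic sketch of the paged view and page-table addressing (shapes made up):

```python
import torch

num_slots, page_size, num_kv_heads, head_dim = 64, 16, 2, 128
key_cache = torch.randn(num_slots, num_kv_heads, head_dim)   # flat slot pool

# View the pool as (num_pages, page_size, heads, head_dim), as the diff does.
paged = key_cache.view(-1, page_size, num_kv_heads, head_dim)
print(paged.shape)                      # torch.Size([4, 16, 2, 128])

# A per-request page table then maps logical pages to physical ones.
page_table = torch.tensor([2, 0, 3])    # this request occupies pages 2, 0, 3
gathered = paged[page_table]            # (3, 16, 2, 128): its KV, page by page
print(gathered.shape)
```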
......@@ -42,8 +42,8 @@ def flash_attn_with_kvcache(
):
return flash_attn_with_kvcache_interface(
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache,
v_cache=v_cache,
k_cache=k_cache.view(q.dtype),
v_cache=v_cache.view(q.dtype),
block_table=page_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
......@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
):
return flash_attn_varlen_func_interface(
q=q,
k=k,
v=v,
k=k.view(q.dtype),
v=v.view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
......@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=return_softmax_lse,
softcap=softcap,
)
\ No newline at end of file
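The wrapper changes above reinterpret the caches (and k/v) as `q.dtype` via `Tensor.view(dtype)` and forward `softcap` to the underlying interface. `view(dtype)` is a byte-level reinterpretation of the same storage, not a numeric cast; for dtypes of equal width the shape is unchanged. My assumption is that this reconciles a dtype mismatch between how the KV pool is allocated and what the fa3 interface checks against. A generic illustration of the reinterpretation:

```python
import torch

x = torch.randn(4, 8, dtype=torch.float16)
y = x.view(torch.bfloat16)                     # same bytes, same shape, no copy
assert y.shape == x.shape
assert y.data_ptr() == x.data_ptr()
assert torch.equal(y.view(torch.float16), x)   # round-trip is lossless
```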
......@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu
self.extend_num_tokens = extend_num_tokens
self.loc_tensor = torch.tensor([-1], device=self.device)
# Allocate memory
out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
......
......@@ -356,7 +356,8 @@ def alloc_for_extend(
else:
# Paged allocation - build last_loc
last_loc = [
(t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
# (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
(t[-1:] if len(t) > 0 else batch.loc_tensor)
for t in prefix_tensors
]
out_cache_loc = alloc_paged_token_slots_extend(
......
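The last two hunks work together: `ScheduleBatch` now creates `self.loc_tensor` once per batch, and `alloc_for_extend` reuses it instead of constructing a fresh one-element tensor for every request whose prefix is empty, avoiding a small device allocation inside a per-request loop. That performance motivation is my reading of the change, not something measured in the commit. A standalone before/after sketch:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
prefix_tensors = [torch.arange(n, device=device) for n in (0, 3, 0, 5)]

# Before: a fresh one-element tensor for every empty prefix.
last_loc_old = [
    t[-1:] if len(t) > 0 else torch.tensor([-1], device=device)
    for t in prefix_tensors
]

# After: a single placeholder created at batch setup and shared (read-only).
loc_tensor = torch.tensor([-1], device=device)
last_loc_new = [t[-1:] if len(t) > 0 else loc_tensor for t in prefix_tensors]

assert all(torch.equal(a, b) for a, b in zip(last_loc_old, last_loc_new))
```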