Commit f37b05b5 authored by linhai1

Modify code with performance issues.

parent 59b01a00
@@ -89,6 +89,7 @@ class DCUMLABackend(AttentionBackend):
         self.q_data_type = model_runner.dtype
         self.device = model_runner.device
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         self.max_context_len = model_runner.model_config.context_len
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
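The added self.k_scale is the default unit scale that the decode paths below fall back to, allocated once at backend construction instead of on every call. A minimal sketch of the pattern, assuming a simplified backend class and method names that are not part of this commit:

import torch

class DecodeBackendSketch:
    """Illustrative only: caches the default k_scale once at init time."""

    def __init__(self, device: str = "cpu"):
        # Built once here; creating torch.tensor([1.0], ...) inside the hot
        # decode path would add a small allocation and host-to-device copy
        # per layer per step.
        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=device)

    def pick_k_scale(self, layer_k_scale: torch.Tensor | None) -> torch.Tensor:
        # Reuse the cached tensor when the layer has no quantization scale.
        return layer_k_scale if layer_k_scale is not None else self.k_scale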
@@ -388,26 +389,20 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
-        if self.data_type in (
-            getattr(torch, "float8_e4m3fn", None),
-            getattr(torch, "float8_e4m3fnuz", None),
-            getattr(torch, "float8_e5m2", None),
-            getattr(torch, "float8_e5m2fnuz", None),
-        ):
-            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
-                k_cache_reshaped.dtype == torch.float8_e4m3fn:
+        if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+                              torch.float8_e5m2, torch.float8_e5m2fnuz):
+            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                 kv_cache_dtype="fp8_e4m3"
-            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
-                k_cache_reshaped.dtype == torch.float8_e5m2:
+            else:
                 kv_cache_dtype="fp8_e5m2"
-            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
             o = self._call_fp8_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 forward_batch.seq_lens.to(torch.int32),
                 layer.scaling,
-                k_scale.to(torch.float32),
+                k_scale,
                 kv_cache_dtype=kv_cache_dtype,
             )
         else:
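The rewritten branch keys the kernel's kv_cache_dtype string off self.data_type directly instead of re-checking the reshaped cache tensor. The same mapping as a standalone helper, shown only as a sketch (the helper name is not part of the codebase; older torch builds may lack the fnuz dtypes, which is what the removed getattr guards were for):

import torch

# The fnuz dtypes are the ROCm/HIP flavors of the two OCP fp8 formats.
# On torch builds that do not expose them, fall back to getattr(...) guards
# as the removed code did.
_E4M3_DTYPES = (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
_E5M2_DTYPES = (torch.float8_e5m2, torch.float8_e5m2fnuz)

def fp8_kv_cache_dtype_str(data_type: torch.dtype) -> str | None:
    """Map a torch fp8 dtype to the decode kernel's kv_cache_dtype string."""
    if data_type in _E4M3_DTYPES:
        return "fp8_e4m3"
    if data_type in _E5M2_DTYPES:
        return "fp8_e5m2"
    return None  # not an fp8 KV cache; caller takes the non-fp8 path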
@@ -460,26 +455,20 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
-        if self.data_type in (
-            getattr(torch, "float8_e4m3fn", None),
-            getattr(torch, "float8_e4m3fnuz", None),
-            getattr(torch, "float8_e5m2", None),
-            getattr(torch, "float8_e5m2fnuz", None),
-        ):
-            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
-                k_cache_reshaped.dtype == torch.float8_e4m3fn:
+        if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+                              torch.float8_e5m2, torch.float8_e5m2fnuz):
+            if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
                 kv_cache_dtype="fp8_e4m3"
-            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
-                k_cache_reshaped.dtype == torch.float8_e5m2:
+            else:
                 kv_cache_dtype="fp8_e5m2"
-            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+            k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
             o = self._call_fp8_decode(
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
-                k_scale.to(torch.float32),
+                k_scale,
                 kv_cache_dtype=kv_cache_dtype,
             )
         else:
......
@@ -695,7 +695,6 @@ class FlashAttentionBackend(AttentionBackend):
             # has corresponding quantization method so that layer.k_scale is not None,
             # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
             # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
-            data_dtype = q.dtype
             if (
                 self.kv_cache_dtype_str != "auto"
                 and layer.head_dim <= 256
@@ -705,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
-                q = q.to(self.kv_cache_dtype)
+                # q = q.to(self.kv_cache_dtype)
                 q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
                 k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
             causal = True
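For context on the descale lines kept above: expand() broadcasts the per-layer scalar scale to the (batch, kv_heads) layout the fa3 kernel consumes without copying data. A self-contained illustration with made-up shapes:

import torch

batch_size, tp_k_head_num = 4, 8

# layer.k_scale / layer.v_scale are stored as one-element float tensors.
k_scale = torch.tensor([0.5], dtype=torch.float32)

# expand() broadcasts without copying: every (batch, head) slot aliases the
# same underlying scalar, so building the descale view per call is cheap.
k_descale = k_scale.expand(batch_size, tp_k_head_num)

assert k_descale.shape == (batch_size, tp_k_head_num)
assert k_descale.data_ptr() == k_scale.data_ptr()  # no new storage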
@@ -830,8 +829,6 @@ class FlashAttentionBackend(AttentionBackend):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
-            k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
-            v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
                 assert not get_global_server_args().disable_chunked_prefix_cache
@@ -845,9 +842,9 @@ class FlashAttentionBackend(AttentionBackend):
                 assert forward_batch.mha_return_lse
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
-                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
-                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -859,9 +856,9 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             else:
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
-                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
-                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
......
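The removed lines dequantized k/v explicitly (a multiply by the descale plus .to(data_dtype), which allocates and converts every element of every chunk); the new lines only reinterpret the existing storage with Tensor.view(dtype). A short standalone sketch of the difference between the two operations, not taken from the backend:

import torch

x = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)

# .to() converts values: a new tensor is allocated and every element is
# re-encoded, so the numbers are preserved across the dtype change.
converted = x.to(torch.float16)

# .view(dtype) reinterprets the same bytes with no copy: essentially free,
# but only meaningful when the storage already holds the target format.
reinterpreted = x.view(torch.float16)

assert converted.data_ptr() != x.data_ptr()      # new storage
assert reinterpreted.data_ptr() == x.data_ptr()  # shared storage
print(converted)      # tensor([1., 2.], dtype=torch.float16)
print(reinterpreted)  # bf16 bit patterns read as fp16: different numbers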
@@ -1301,6 +1301,15 @@ class MLATokenToKVPool(KVCache):
             return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.kv_buffer[layer_id - self.start_layer]
 
+    def get_key_buffer_DeepSeekV2(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        if self.store_dtype != self.dtype and self.dtype not in (
+            torch.float8_e5m2, torch.float8_e4m3fn):
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype), self.dtype
+        return self.kv_buffer[layer_id - self.start_layer], self.dtype
+
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
......
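The point of returning the raw buffer together with its logical dtype is that the caller can index only the chunk it needs and then reinterpret and convert just that chunk, instead of casting the whole KV pool up front as the old get_key_buffer(...).to(q.dtype) call did. A hedged illustration of that ordering with made-up buffer sizes and dtypes:

import torch

# Stand-in for a KV pool stored as raw bytes that logically hold fp8 values.
store = torch.zeros(100_000, 64, dtype=torch.uint8)
logical_dtype = torch.float8_e4m3fn
chunk_indices = torch.arange(0, 4096)

# Old ordering (expensive): convert the full pool, then slice the chunk.
#   full = store.view(logical_dtype).to(torch.bfloat16)[chunk_indices]

# New ordering (cheap): slice first, then view/convert only the chunk.
chunk = store[chunk_indices].contiguous().view(logical_dtype).to(torch.bfloat16)
assert chunk.shape == (4096, 64)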
@@ -1624,12 +1624,14 @@ class ModelRunner:
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e5m2fnuz
+                # self.kv_cache_dtype = torch.float8_e5m2fnuz
+                self.kv_cache_dtype = torch.float8_e5m2
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e4m3fnuz
+                # self.kv_cache_dtype = torch.float8_e4m3fnuz
+                self.kv_cache_dtype = torch.float8_e4m3fn
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
......
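A compact restatement of the selection logic after this change, written as a hypothetical helper rather than the ModelRunner code; the "auto" and bf16 branches are inferred from the surrounding context lines, and the effect of the change is that the HIP path now also picks the OCP fp8 formats instead of the fnuz variants:

import torch

def select_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical helper mirroring ModelRunner's branches after this commit.
    if kv_cache_dtype == "auto":
        return model_dtype                # assumed from the surrounding code
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2          # now used on HIP as well
    if kv_cache_dtype == "fp8_e4m3":
        return torch.float8_e4m3fn        # now used on HIP as well
    if kv_cache_dtype in ("bf16", "bfloat16"):
        return torch.bfloat16             # assumed mapping for the bf16 branch
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")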
@@ -2294,12 +2294,13 @@ class DeepseekV2AttentionMLA(nn.Module):
             forward_batch.set_prefix_chunk_idx(i)
 
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+            latent_cache_buf, dtype = forward_batch.token_to_kv_pool.get_key_buffer_DeepSeekV2(
                 self.attn_mha.layer_id
-            ).to(q.dtype)
+            )
             latent_cache = (
                 latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
                 .contiguous()
+                .view(dtype)
                 .to(q.dtype)
             )
......