Commit eb4b015f authored by lizhigong

Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'

Modify code paths with performance issues.

See merge request OpenDAS/sglang!20
parents cbf7a3e3 f37b05b5
......@@ -89,6 +89,7 @@ class DCUMLABackend(AttentionBackend):
self.q_data_type = model_runner.dtype
self.device = model_runner.device
+ self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self.max_context_len = model_runner.model_config.context_len
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
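
The line added above caches a constant 1.0 scale tensor on the target device at construction time, so later decode calls can fall back to it instead of allocating a fresh tensor (and paying a host-to-device copy) on every step. A minimal sketch of the pattern; the class and method names here are illustrative, not the backend's real interface:

```python
from typing import Optional

import torch


class ScaleCache:
    """Illustrative only: hold a constant fallback scale instead of rebuilding it per call."""

    def __init__(self, device: str = "cpu"):
        # Allocated once at construction; decode steps reuse it when no layer scale exists.
        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=device)

    def resolve_scale(self, layer_scale: Optional[torch.Tensor]) -> torch.Tensor:
        # Per-call fallback without a new allocation.
        return layer_scale if layer_scale is not None else self.k_scale


cache = ScaleCache()
print(cache.resolve_scale(None))                 # tensor([1.])
print(cache.resolve_scale(torch.tensor([0.5])))  # tensor([0.5000])
```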
......@@ -388,26 +389,20 @@ class DCUMLABackend(AttentionBackend):
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
- if self.data_type in (
- getattr(torch, "float8_e4m3fn", None),
- getattr(torch, "float8_e4m3fnuz", None),
- getattr(torch, "float8_e5m2", None),
- getattr(torch, "float8_e5m2fnuz", None),
- ):
- if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
- k_cache_reshaped.dtype == torch.float8_e4m3fn:
+ if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+ torch.float8_e5m2, torch.float8_e5m2fnuz):
+ if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
kv_cache_dtype="fp8_e4m3"
- elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
- k_cache_reshaped.dtype == torch.float8_e5m2:
+ else:
kv_cache_dtype="fp8_e5m2"
- k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+ k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
o = self._call_fp8_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
layer.scaling,
- k_scale.to(torch.float32),
+ k_scale,
kv_cache_dtype=kv_cache_dtype,
)
else:
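
The rewritten branch above keys the kernel's `kv_cache_dtype` string off `self.data_type` directly instead of re-inspecting the reshaped cache tensor's dtype on every call. A standalone sketch of the same dispatch, assuming a torch build that exposes all four fp8 dtypes; the helper name is made up:

```python
from typing import Optional

import torch


def select_kv_cache_dtype(data_type: torch.dtype) -> Optional[str]:
    """Sketch: map an fp8 torch dtype to the decode kernel's kv_cache_dtype string."""
    if data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                     torch.float8_e5m2, torch.float8_e5m2fnuz):
        if data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
            return "fp8_e4m3"
        return "fp8_e5m2"
    return None  # non-fp8 caches take the regular decode path


print(select_kv_cache_dtype(torch.float8_e5m2))  # fp8_e5m2
print(select_kv_cache_dtype(torch.bfloat16))     # None
```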
......@@ -460,26 +455,20 @@ class DCUMLABackend(AttentionBackend):
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
- if self.data_type in (
- getattr(torch, "float8_e4m3fn", None),
- getattr(torch, "float8_e4m3fnuz", None),
- getattr(torch, "float8_e5m2", None),
- getattr(torch, "float8_e5m2fnuz", None),
- ):
- if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
- k_cache_reshaped.dtype == torch.float8_e4m3fn:
+ if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
+ torch.float8_e5m2, torch.float8_e5m2fnuz):
+ if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
kv_cache_dtype="fp8_e4m3"
- elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
- k_cache_reshaped.dtype == torch.float8_e5m2:
+ else:
kv_cache_dtype="fp8_e5m2"
- k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
+ k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
o = self._call_fp8_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
layer.scaling,
- k_scale.to(torch.float32),
+ k_scale,
kv_cache_dtype=kv_cache_dtype,
)
else:
......
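
This hunk mirrors the decode path above; the only functional difference is the sequence-length argument, which is extended by `self.num_draft_tokens` so the kernel also attends over the speculative draft tokens being verified. A tiny illustration of that offset with made-up values:

```python
import torch

# Hypothetical cached prefix lengths for three requests.
seq_lens = torch.tensor([17, 5, 42])
# Draft tokens appended to each request for target verification.
num_draft_tokens = 4

kernel_seq_lens = (seq_lens + num_draft_tokens).to(torch.int32)
print(kernel_seq_lens)  # tensor([21,  9, 46], dtype=torch.int32)
```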
......@@ -695,7 +695,6 @@ class FlashAttentionBackend(AttentionBackend):
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
- data_dtype = q.dtype
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
......@@ -705,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
- q = q.to(self.kv_cache_dtype)
+ # q = q.to(self.kv_cache_dtype)
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
causal = True
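
With the query cast commented out, only the rope components are converted to the kv-cache dtype, while the per-layer scales are still broadcast to `(batch_size, tp_k_head_num)` through `expand`, which produces a view rather than a copy. A small sketch of that broadcast with made-up shapes:

```python
import torch

# A per-layer scalar scale, as a (1, 1) tensor.
k_scale = torch.tensor([[0.25]], dtype=torch.float32)

# Hypothetical descale shape: (batch_size=3, tp_k_head_num=2).
k_descale = k_scale.expand(3, 2)

print(k_descale.shape)                             # torch.Size([3, 2])
print(k_descale.data_ptr() == k_scale.data_ptr())  # True: expand returns a view, no copy
```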
......@@ -830,8 +829,6 @@ class FlashAttentionBackend(AttentionBackend):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
- k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
- v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
......@@ -845,9 +842,9 @@ class FlashAttentionBackend(AttentionBackend):
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
- q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
- k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
- v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
......@@ -859,9 +856,9 @@ class FlashAttentionBackend(AttentionBackend):
)
else:
output = flash_attn_varlen_func(
- q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
- k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
- v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
......
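
Both varlen calls now pass the key and value tensors through `Tensor.view(q.dtype)` instead of multiplying by the descale factors and converting with `.to(data_dtype)`. For reference, `.to(dtype)` converts values element by element into new storage, while `.view(dtype)` reinterprets the existing bytes without any conversion (element sizes must be compatible). A standalone comparison on toy data, unrelated to the kernel call itself:

```python
import torch

x = torch.tensor([1.0, 2.0], dtype=torch.float16)

converted = x.to(torch.bfloat16)        # value-preserving cast, allocates new storage
reinterpreted = x.view(torch.bfloat16)  # same 2-byte patterns read as bfloat16, no copy

print(converted)      # tensor([1., 2.], dtype=torch.bfloat16)
print(reinterpreted)  # different numeric values: fp16 bit patterns decoded as bfloat16
```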
......@@ -1301,6 +1301,15 @@ class MLATokenToKVPool(KVCache):
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer]
+ def get_key_buffer_DeepSeekV2(self, layer_id: int):
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+ if self.store_dtype != self.dtype and self.dtype not in (
+ torch.float8_e5m2, torch.float8_e4m3fn):
+ return self.kv_buffer[layer_id - self.start_layer].view(self.dtype), self.dtype
+ return self.kv_buffer[layer_id - self.start_layer], self.dtype
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
......
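
The new accessor returns the raw per-layer buffer together with the dtype it should be interpreted as, so callers can index first and reinterpret only the gathered slice. A toy usage sketch; the pool class and method name below are stand-ins, not the real `MLATokenToKVPool` API:

```python
import torch


class ToyKVPool:
    """Stand-in pool: hand back a buffer plus the dtype the caller should view it as."""

    def __init__(self):
        self.dtype = torch.bfloat16
        self.start_layer = 0
        # One toy layer: 8 cached tokens x 16 latent dims.
        self.kv_buffer = [torch.randn(8, 16, dtype=self.dtype)]

    def get_key_buffer_with_dtype(self, layer_id: int):
        # Return the buffer untouched and let the caller decide when to reinterpret it.
        return self.kv_buffer[layer_id - self.start_layer], self.dtype


pool = ToyKVPool()
buf, dtype = pool.get_key_buffer_with_dtype(0)
indices = torch.tensor([0, 3, 5])
latent = buf[indices].contiguous().view(dtype).to(torch.float32)
print(latent.shape, latent.dtype)  # torch.Size([3, 16]) torch.float32
```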
......@@ -1624,12 +1624,14 @@ class ModelRunner:
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if _is_hip: # Using natively supported format
- self.kv_cache_dtype = torch.float8_e5m2fnuz
+ # self.kv_cache_dtype = torch.float8_e5m2fnuz
+ self.kv_cache_dtype = torch.float8_e5m2
else:
self.kv_cache_dtype = torch.float8_e5m2
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
if _is_hip: # Using natively supported format
- self.kv_cache_dtype = torch.float8_e4m3fnuz
+ # self.kv_cache_dtype = torch.float8_e4m3fnuz
+ self.kv_cache_dtype = torch.float8_e4m3fn
else:
self.kv_cache_dtype = torch.float8_e4m3fn
elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
......
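
On HIP builds the fp8 kv-cache options now resolve to the OCP formats (`torch.float8_e5m2`, `torch.float8_e4m3fn`), matching the non-HIP branch, with the fnuz variants left commented out. A compact sketch of the resulting mapping from the `--kv-cache-dtype` server argument, assuming the installed torch exposes these dtypes; the helper name is illustrative:

```python
import torch


def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    """Sketch of the server-arg -> torch dtype mapping after this change."""
    if kv_cache_dtype == "auto":
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    if kv_cache_dtype == "fp8_e4m3":
        return torch.float8_e4m3fn
    if kv_cache_dtype in ("bf16", "bfloat16"):
        return torch.bfloat16
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")


print(resolve_kv_cache_dtype("fp8_e4m3", torch.bfloat16))  # torch.float8_e4m3fn
```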
......@@ -2294,12 +2294,13 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch.set_prefix_chunk_idx(i)
# Fetch latent cache from memory pool with precomputed chunked kv indices
- latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+ latent_cache_buf, dtype = forward_batch.token_to_kv_pool.get_key_buffer_DeepSeekV2(
self.attn_mha.layer_id
- ).to(q.dtype)
+ )
latent_cache = (
latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
.contiguous()
+ .view(dtype)
.to(q.dtype)
)
......
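
After this change the chunked-prefix path indexes the raw latent buffer first, reinterprets the gathered slice with the dtype reported by the pool, and only then converts it to the query dtype, so the full cache buffer is never materialized in `q.dtype`. A toy sketch of that ordering with made-up shapes and dtypes:

```python
import torch

# Hypothetical latent cache: 64 cached tokens x 576 latent dims.
latent_cache_buf = torch.randn(64, 576, dtype=torch.bfloat16)
buf_dtype = torch.bfloat16
q_dtype = torch.float16

# Gather only the chunk's tokens, then reinterpret and convert that small slice.
chunk_kv_indices = torch.tensor([3, 7, 11, 42])
latent_cache = (
    latent_cache_buf[chunk_kv_indices]
    .contiguous()
    .view(buf_dtype)  # bitwise reinterpretation; a no-op here since the dtypes already match
    .to(q_dtype)      # value conversion applied to the gathered slice only
)
print(latent_cache.shape, latent_cache.dtype)  # torch.Size([4, 576]) torch.float16
```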