"vscode:/vscode.git/clone" did not exist on "df214215254fd1a50c137ed77f007be643100a1b"
Unverified commit 6cb32ef9 authored by Ke Bao, committed by GitHub

Support Triton fp8 e5m2 kv cache (#1286)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 761b2ceb
@@ -128,7 +128,7 @@ def _fwd_kernel(
         k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk += tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
             offs_kpe = (
                 offs_kv_loc[None, :] * stride_buf_kbs
@@ -140,7 +140,7 @@ def _fwd_kernel(
                 mask=mask_n[None, :],
                 other=0.0,
             )
-            qk += tl.dot(qpe, kpe)
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -276,9 +276,17 @@ def extend_attention_fwd(
     BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 9:
-        BLOCK_M, BLOCK_N = (128, 64)
+        if Lq <= 256:
+            BLOCK_M, BLOCK_N = (128, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     elif CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+        if Lq <= 128:
+            BLOCK_M, BLOCK_N = (128, 128)
+        elif Lq <= 256:
+            BLOCK_M, BLOCK_N = (64, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
......
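The widened tile table above adds branches for head dimensions beyond 256, trading tile size against shared-memory and register pressure per architecture. A standalone sketch of the resulting selection logic (choose_tiles is a hypothetical helper name; the branches are copied from the diff):

def choose_tiles(cuda_capability_major: int, lq: int) -> tuple[int, int]:
    """Pick (BLOCK_M, BLOCK_N) as in the updated extend_attention_fwd.

    Bigger tiles amortize more work per Triton program, but large head
    dims (lq) inflate shared-memory/register use, forcing smaller tiles.
    """
    if cuda_capability_major >= 9:      # Hopper and newer
        return (128, 64) if lq <= 256 else (32, 64)
    if cuda_capability_major >= 8:      # Ampere / Ada
        if lq <= 128:
            return (128, 128)
        return (64, 64) if lq <= 256 else (32, 64)
    return (64, 64) if lq <= 128 else (32, 32)  # pre-Ampere fallback

assert choose_tiles(9, 128) == (128, 64)
assert choose_tiles(8, 576) == (32, 64)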
@@ -348,12 +348,6 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
......
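With the Triton extend kernel now able to read an fp8 cache, the warning-and-fallback path is dropped and fp8_e5m2 maps straight to torch.float8_e5m2 on every backend. A hedged standalone rendering of the resulting dtype resolution (the function name and free-floating form are assumptions; in the source this logic lives inside ModelRunner):

import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Sketch of the post-patch ModelRunner logic, lifted out of the class.
    if kv_cache_dtype == "auto":
        return model_dtype              # KV cache matches the model's compute dtype
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2        # honored for FlashInfer and Triton alike
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype!r}")

print(resolve_kv_cache_dtype("fp8_e5m2", torch.bfloat16))  # torch.float8_e5m2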