Unverified commit 6cb32ef9 authored by Ke Bao, committed by GitHub
Browse files

Support Triton fp8 e5m2 kv cache (#1286)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 761b2ceb
......@@ -128,7 +128,7 @@ def _fwd_kernel(
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk += tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
......@@ -140,7 +140,7 @@ def _fwd_kernel(
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
if logit_cap > 0:
......@@ -276,9 +276,17 @@ def extend_attention_fwd(
BLOCK_DV = Lv
if CUDA_CAPABILITY[0] >= 9:
BLOCK_M, BLOCK_N = (128, 64)
if Lq <= 256:
BLOCK_M, BLOCK_N = (128, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
elif CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
......
......@@ -348,13 +348,7 @@ class ModelRunner:
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if self.server_args.disable_flashinfer or self.server_args.enable_mla:
logger.warning(
"FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
)
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = torch.float8_e5m2
self.kv_cache_dtype = torch.float8_e5m2
else:
raise ValueError(
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment