Unverified Commit 9db80253 authored by Rain Jiang's avatar Rain Jiang Committed by GitHub
Browse files

support fp8 kvcache for hybrid attn backend on GPT-OSS (#9783)


Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 598c0bc1
......@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
def _enable_fused_set_kv_buffer():
return _is_cuda
def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
# TODO maybe move to a model-common utils
......@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
layer=self.attn,
forward_batch=forward_batch,
)
if _enable_fused_set_kv_buffer()
if _enable_fused_set_kv_buffer(forward_batch)
else None
),
)
......@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not _enable_fused_set_kv_buffer(),
save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment