Unverified commit 9db80253, authored by Rain Jiang, committed by GitHub

Support fp8 KV cache for the hybrid attention backend on GPT-OSS (#9783)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 598c0bc1
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans

-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16

 # TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                 layer=self.attn,
                 forward_batch=forward_batch,
             )
-            if _enable_fused_set_kv_buffer()
+            if _enable_fused_set_kv_buffer(forward_batch)
             else None
         ),
     )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
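For context, here is a minimal, self-contained sketch of the gate this diff introduces. `ForwardBatch` and `token_to_kv_pool` follow the names in the diff; `SimpleKVPool` and the demo under `__main__` are hypothetical stand-ins for illustration, not SGLang's real classes.

```python
# Sketch of the dtype gate from this commit. SimpleKVPool and the demo
# below are hypothetical stand-ins; only _enable_fused_set_kv_buffer
# mirrors the actual change.
from dataclasses import dataclass

import torch


@dataclass
class SimpleKVPool:
    # Stand-in for the real token_to_kv_pool: only the dtype matters here.
    dtype: torch.dtype


@dataclass
class ForwardBatch:
    token_to_kv_pool: SimpleKVPool


_is_cuda = torch.cuda.is_available()


def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch) -> bool:
    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache.

    Per the diff, the fused path assumes a bf16 cache, so any other KV
    cache dtype (e.g. an fp8 cache such as torch.float8_e4m3fn) falls
    back to save_kv_cache=True, letting the attention backend write the
    KV cache itself.
    """
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


if __name__ == "__main__":
    for dtype in (torch.bfloat16, torch.float8_e4m3fn):
        batch = ForwardBatch(SimpleKVPool(dtype))
        fused = _enable_fused_set_kv_buffer(batch)
        # Mirrors the attention call sites in the diff: exactly one of
        # the two KV write paths is active for a given cache dtype.
        print(dtype, "fused:", fused, "save_kv_cache:", not fused)
```

The design keeps the two write paths mutually exclusive: `save_kv_cache` is set to the negation of the fused-path check, so a bf16 cache uses the fused kernel and an fp8 cache is written by the backend's own save path.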