Unverified Commit 3053a22b authored by Aleksandr Malyshev's avatar Aleksandr Malyshev Committed by GitHub
Browse files

fp8 kv cache support fix for torch.compile (#22758)


Signed-off-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Co-authored-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarGregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
parent 02d4b854
...@@ -125,7 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -125,7 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# These are used in the final Attention.forward() # These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale) layer._q_scale.copy_(q_scale)
layer._q_scale_float = q_scale layer._q_scale_float = q_scale.item() if isinstance(
q_scale, torch.Tensor) else q_scale
layer._prob_scale.copy_(prob_scale) layer._prob_scale.copy_(prob_scale)
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
or prob_scale == 1.0): or prob_scale == 1.0):
......
...@@ -361,7 +361,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -361,7 +361,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \ assert layer._q_scale_float == 1.0, \
"A non 1.0 q_scale is not currently supported." "A non 1.0 q_scale is not currently supported."
if current_platform.is_cuda(): if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda # Skip Q quantization on ROCm and XPU, enable this on cuda
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment