Unverified Commit 140cbb11 authored by JartX's avatar JartX Committed by GitHub
Browse files

[Bugfix] Cuda Clean up scales Kvcache fp8/int8_per_token_head (#39224)


Signed-off-by: default avatarJartX <sagformas@epdcenter.es>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 6155bbd1
...@@ -5859,6 +5859,13 @@ class GPUModelRunner( ...@@ -5859,6 +5859,13 @@ class GPUModelRunner(
layer.kv_cache = ( layer.kv_cache = (
torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else [] torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else []
) )
# Clean up quantized KV cache scale views
# (int8_per_token_head, fp8_per_token_head)
if hasattr(layer, "impl"):
if hasattr(layer.impl, "_k_scale_cache"):
layer.impl._k_scale_cache = None
if hasattr(layer.impl, "_v_scale_cache"):
layer.impl._v_scale_cache = None
gc.collect() gc.collect()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment