Commit 718337a7 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix local kv_cache_dtype_str

parent fc55a25c
......@@ -1214,6 +1214,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_q = q_ori[:num_decode_tokens]
prefill_q = q_ori[num_decode_tokens:]
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1226,14 +1234,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
else:
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype_str=="fp8_e4m3" and envs.VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA:
if has_prefill:
fused_rms_norm_rope_contiguous(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment