Commit ee93cb70 authored by zhuwenwen's avatar zhuwenwen
Browse files

update self._q_scale

parent 06106338
......@@ -101,12 +101,13 @@ class Attention(nn.Module):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32)
self._q_scale = torch.ones((1), dtype=torch.float32)
else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment