Commit 8f30468c authored by zhuwenwen's avatar zhuwenwen
Browse files

update q_quant dtype

parent c7b0d0d4
......@@ -1253,7 +1253,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
)
else:
q_tensor = torch.randn(q.shape[0], num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, dtype=q.dtype, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=kv_cache_dtype_str, device=q.device)
q_quant = torch.empty_like(q_tensor, dtype=torch.float8_e4m3fn, device=q.device)
q_scale = torch.empty(q.shape[0], dtype=torch.float32, device=q.device)
fuse_rmsnorm_rope_quant_qkv(
positions[:num_actual_toks, ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment