Commit 3b121add authored by zhuwenwen's avatar zhuwenwen
Browse files

update kv_b_proj_weight

parent a1b2eff7
......@@ -1102,6 +1102,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod):
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment