Commit b8412df6 authored by zhuwenwen
Browse files

update MLACommonBaseImpl get_and_maybe_dequant_weights

parent 24962bed
@@ -1097,12 +1097,15 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
                 del eye
                 # standardize to (output, input)
                 return dequant_weights.T
-            return layer.weight
+            return layer.weight if not envs.VLLM_USE_NN else layer.weight.T
 
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod):
+            kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
+        else:
+            kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment