Commit 10cdc93d authored by zhuwenwen's avatar zhuwenwen
Browse files

update mla layout

parent 38571cde
......@@ -194,7 +194,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
import torch
import os
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
......@@ -1047,6 +1047,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
......@@ -1098,7 +1100,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod):
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
......@@ -1410,4 +1415,4 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output
return output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment