Commit 10cdc93d authored by zhuwenwen's avatar zhuwenwen
Browse files

update mla layout

parent 38571cde
...@@ -194,7 +194,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, ...@@ -194,7 +194,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar) Type, TypeVar)
import torch import torch
import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
...@@ -1048,6 +1048,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1048,6 +1048,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
functools.partial(flash_attn_varlen_func, functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
...@@ -1098,6 +1100,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1098,6 +1100,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# we currently do not have quantized bmm's which are needed for # we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low # the bmm's in 16-bit, the extra memory overhead of this is fairly low
if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod):
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment