Commit 371b1251 authored by zhuwenwen
Browse files

Add FA pad (FA_PAD) support

parent 1863c926
......@@ -607,6 +607,7 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def weight_loader(self,
param: Parameter,
......@@ -763,8 +764,12 @@ class QKVParallelLinear(ColumnParallelLinear):
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
if self.use_fa_pad and param_data.shape[0]== 12288:
param_data = pad_weight(param.data, 32)
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
if self.use_fa_pad and param_data.shape[0]== 12288 and loaded_shard_id == "v" and len(param_data.shape) == 1:
param.data = pad_weight(param.data, 32)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......
......@@ -27,6 +27,8 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0':
os.environ['GEMM_PAD'] = '1'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......
......@@ -24,6 +24,7 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
import os
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -178,6 +179,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
......
......@@ -7,6 +7,7 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
import os
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -101,6 +102,8 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
......
......@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
import os
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -156,6 +157,8 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......
......@@ -27,6 +27,7 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
import os
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -148,6 +149,8 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment