Commit 1fabf3e1 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.2-dtk24.04.1

# Conflicts:
#	vllm/model_executor/layers/linear.py
#	vllm/model_executor/models/baichuan.py
parents 1e77d04e 0b5e4e11
......@@ -59,17 +59,14 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
```
1. 编译whl包并安装
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
python csrc/quantization/gptq/setup.py bdist_wheel
cd dist
pip install vllm*
cd csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernel
2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
cd csrc/quantization/gptq
python setup.py install
python csrc/quantization/gptq/setup.py install
```
#### 运行基础环境准备
......
......@@ -24,8 +24,8 @@ setup(
CUDAExtension(
name="gptq_kernels",
sources=[
"./torch_bindings.cpp",
"./q_gemm.cu",
"csrc/quantization/gptq/torch_bindings.cpp",
"csrc/quantization/gptq/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
)
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
import os
from vllm.model_executor.utils import gemm_bank_conf
logger = init_logger(__name__)
......@@ -105,7 +106,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
......@@ -125,6 +127,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
weight = weight[:,:-32]
if bias is not None:
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
......
......@@ -22,11 +22,19 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......
......@@ -50,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
......@@ -185,6 +186,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
......@@ -339,6 +342,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
......@@ -419,7 +424,13 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA
......@@ -106,6 +108,8 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
......@@ -363,6 +367,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
......@@ -416,7 +422,13 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -55,6 +55,8 @@ from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class LlamaMLP(nn.Module):
......@@ -162,6 +164,8 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -384,6 +388,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
......@@ -496,7 +502,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -34,6 +34,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class QWenMLP(nn.Module):
def __init__(
......@@ -121,6 +124,8 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -203,7 +208,6 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -245,6 +249,8 @@ class QWenLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
......@@ -322,6 +328,12 @@ class QWenLMHeadModel(nn.Module):
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -53,6 +53,9 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class Qwen2MLP(nn.Module):
def __init__(
......@@ -153,6 +156,8 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
......@@ -327,6 +332,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
......@@ -414,6 +421,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -33,3 +33,32 @@ def set_weight_attrs(
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
if weight.dim() == 1:
padding = torch.zeros(num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif weight.dim() == 2:
if pad_dim == 0:
padding = torch.zeros(num_pad, weight.shape[1], dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif pad_dim == 1:
padding = torch.zeros(weight.shape[0], num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=1)
else:
raise ValueError("pad_dim must be 0 or 1")
else:
raise ValueError("Weight tensor must be 1D or 2D")
padded_weight = padded_weight.contiguous()
return padded_weight
def gemm_bank_conf(weight):
is_mul_of_2048 = weight % 2048 == 0
is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
if is_mul_of_2048 and is_power_of_two:
return True
else:
return False
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment