Commit 1fabf3e1 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.2-dtk24.04.1

# Conflicts:
#	vllm/model_executor/layers/linear.py
#	vllm/model_executor/models/baichuan.py
parents 1e77d04e 0b5e4e11
...@@ -59,17 +59,14 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的 ...@@ -59,17 +59,14 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
``` ```
1. 编译whl包并安装 1. 编译whl包并安装
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
python csrc/quantization/gptq/setup.py bdist_wheel
cd dist cd dist
pip install vllm* pip install vllm*
cd csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernel pip install gptq_kernel
2. 源码编译安装 2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
cd csrc/quantization/gptq python csrc/quantization/gptq/setup.py install
python setup.py install
``` ```
#### 运行基础环境准备 #### 运行基础环境准备
......
...@@ -24,8 +24,8 @@ setup( ...@@ -24,8 +24,8 @@ setup(
CUDAExtension( CUDAExtension(
name="gptq_kernels", name="gptq_kernels",
sources=[ sources=[
"./torch_bindings.cpp", "csrc/quantization/gptq/torch_bindings.cpp",
"./q_gemm.cu", "csrc/quantization/gptq/q_gemm.cu",
], ],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
) )
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
from vllm.model_executor.utils import gemm_bank_conf
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -105,6 +106,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -105,6 +106,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self): def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -125,6 +127,9 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -125,6 +127,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
weight = weight[:,:-32]
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight) return torch.addmm(bias, x, layer.weight)
......
...@@ -22,11 +22,19 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,11 +22,19 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']: support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
...@@ -50,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -50,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
...@@ -185,6 +186,8 @@ class BaiChuanAttention(nn.Module): ...@@ -185,6 +186,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -339,6 +342,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -339,6 +342,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -420,6 +425,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -420,6 +425,12 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -29,7 +29,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -106,6 +108,8 @@ class GLMAttention(nn.Module): ...@@ -106,6 +108,8 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(
...@@ -363,6 +367,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -363,6 +367,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -417,6 +423,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -417,6 +423,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -55,6 +55,8 @@ from .interfaces import SupportsLoRA ...@@ -55,6 +55,8 @@ from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -162,6 +164,8 @@ class LlamaAttention(nn.Module): ...@@ -162,6 +164,8 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -384,6 +388,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -384,6 +388,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -497,6 +503,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -497,6 +503,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -34,6 +34,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -34,6 +34,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
def __init__( def __init__(
...@@ -121,6 +124,8 @@ class QWenAttention(nn.Module): ...@@ -121,6 +124,8 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -203,7 +208,6 @@ class QWenModel(nn.Module): ...@@ -203,7 +208,6 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -245,6 +249,8 @@ class QWenLMHeadModel(nn.Module): ...@@ -245,6 +249,8 @@ class QWenLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -322,6 +328,12 @@ class QWenLMHeadModel(nn.Module): ...@@ -322,6 +328,12 @@ class QWenLMHeadModel(nn.Module):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -53,6 +53,9 @@ from vllm.utils import print_warning_once ...@@ -53,6 +53,9 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
def __init__( def __init__(
...@@ -153,6 +156,8 @@ class Qwen2Attention(nn.Module): ...@@ -153,6 +156,8 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
...@@ -327,6 +332,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -327,6 +332,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
...@@ -414,6 +421,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -414,6 +421,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
...@@ -33,3 +33,32 @@ def set_weight_attrs( ...@@ -33,3 +33,32 @@ def set_weight_attrs(
assert not hasattr( assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}") weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value) setattr(weight, key, value)
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
if weight.dim() == 1:
padding = torch.zeros(num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif weight.dim() == 2:
if pad_dim == 0:
padding = torch.zeros(num_pad, weight.shape[1], dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif pad_dim == 1:
padding = torch.zeros(weight.shape[0], num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=1)
else:
raise ValueError("pad_dim must be 0 or 1")
else:
raise ValueError("Weight tensor must be 1D or 2D")
padded_weight = padded_weight.contiguous()
return padded_weight
def gemm_bank_conf(weight):
is_mul_of_2048 = weight % 2048 == 0
is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
if is_mul_of_2048 and is_power_of_two:
return True
else:
return False
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment