Commit 835bd9fc authored by gaoqiong's avatar gaoqiong
Browse files

修改nn支持方式

parent 7fe40ced
......@@ -119,6 +119,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
......
......@@ -1542,6 +1542,26 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
}
}
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row; //dst的列id
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
......@@ -1847,6 +1867,17 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c;
}
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
//row是原矩阵的行,col是原矩阵的列
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::gptq::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
......
......@@ -159,6 +159,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row,int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
......@@ -164,6 +164,10 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
......@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
import os
logger = init_logger(__name__)
......@@ -42,34 +43,6 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
if weight.dim() == 1:
padding = torch.zeros(num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif weight.dim() == 2:
if pad_dim == 0:
padding = torch.zeros(num_pad, weight.shape[1], dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=0)
elif pad_dim == 1:
padding = torch.zeros(weight.shape[0], num_pad, dtype=weight.dtype, device=weight.device)
padded_weight = torch.cat([weight, padding], dim=1)
else:
raise ValueError("pad_dim must be 0 or 1")
else:
raise ValueError("Weight tensor must be 1D or 2D")
return padded_weight
def gemm_bank_conf(weight):
is_mul_of_2048 = weight % 2048 == 0
is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
if is_mul_of_2048 and is_power_of_two:
return True
else:
return False
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
......@@ -115,6 +88,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
......@@ -134,20 +108,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
#print("**************matmul weight.shape:",weight.shape)
#print("self.use_llama_nn:",self.use_llama_nn)
if self.separate_bias_add:
#print("********self.separate_bias_add")
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1)
# print("**************matmul input.shape:",x.shape)
# print("**************matmul weight.shape:",weight.shape)
if bias is not None:
return torch.matmul(x, weight) + bias
return torch.matmul(x, weight) +bias
else:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
return torch.matmul(x, weight[:,:-32])
else:
return torch.matmul(x, weight)
return torch.matmul(x, weight)
else:
return F.linear(x, weight, bias)
......@@ -308,7 +286,6 @@ class ColumnParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
......@@ -330,9 +307,6 @@ class ColumnParallelLinear(LinearBase):
shard_id=0)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight)
def forward(self, input_):
......@@ -397,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
......@@ -477,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
......@@ -527,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
......@@ -597,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
......@@ -604,8 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def weight_loader(self,
param: Parameter,
......@@ -713,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear):
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id)
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
......@@ -752,25 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
if self.use_fa_pad and param_data.shape[0]== 12288:
param_data = pad_weight(param.data, 32)
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
if self.use_fa_pad and param_data.shape[0]== 12288 and loaded_shard_id == "v" and len(param_data.shape) == 1:
param.data = pad_weight(param.data, 32)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
......@@ -839,8 +781,6 @@ class RowParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
......@@ -866,20 +806,7 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
if not self.use_gemm_pad:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight)
else:
param_data.copy_(loaded_weight)
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
else:
param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight)
def forward(self, input_):
# Set up backprop all-reduce.
......
......@@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -50,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once
from vllm import _custom_ops as ops
class LlamaMLP(nn.Module):
......@@ -363,6 +365,7 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -438,8 +441,37 @@ class LlamaForCausalLM(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
# if layername=="model.layers.0.self_attn.qkv_proj.weight":
# print("weight.data[0:5][0:5]:",weight.data[0:5][0:5])
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
......
......@@ -10,6 +10,9 @@ import torch
from torch import nn
from transformers import PretrainedConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -29,7 +32,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
class QWenMLP(nn.Module):
def __init__(
......@@ -199,6 +202,7 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -237,6 +241,7 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -292,3 +297,31 @@ class QWenLMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
......@@ -28,6 +28,7 @@ import torch
from torch import nn
from transformers import Qwen2Config
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
......@@ -48,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
class Qwen2MLP(nn.Module):
def __init__(
......@@ -322,6 +323,7 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -382,3 +384,32 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment