Commit 835bd9fc authored by gaoqiong's avatar gaoqiong
Browse files

修改nn支持方式

parent 7fe40ced
...@@ -119,6 +119,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -119,6 +119,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale); // torch::Tensor& scale);
......
...@@ -1542,6 +1542,26 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, ...@@ -1542,6 +1542,26 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
} }
} }
// Out-of-place transpose of a row-major `row` x `col` matrix.
//
// Each thread produces exactly one output element: thread `id` writes
// dst[id], which is element (i, j) of the `col` x `row` result, and reads the
// matching source element src[j * col + i].
//
// Launch layout: 1-D grid of 1-D blocks providing at least `num_kernels`
// threads in total; surplus threads return immediately. No shared memory.
// `num_kernels` is the number of elements to write (the host launcher in this
// file passes row * col). dst and src should be distinct buffers: every thread
// reads src and writes dst without any synchronization.
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels, T* dst,
                                          const T* src, int64_t row,
                                          int64_t col) {
  // Widen blockIdx.x before multiplying. The built-in index variables are
  // 32-bit unsigned, so `blockIdx.x * blockDim.x` would otherwise wrap around
  // before being stored in the 64-bit index.
  const int64_t id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (id >= num_kernels) return;

  const int64_t j = id % row;  // column index within dst
  const int64_t i = id / row;  // row index within dst
  dst[i * row + j] = src[j * col + i];
}
// Host-side launcher for trans_w16_gemm_cudakernel.
//
// Transposes the row-major `row` x `col` matrix at `src` into `dst` on the
// current CUDA stream. The launch is asynchronous; this function does not
// synchronize. One thread is launched per element, in blocks of 256 threads.
//
// dst / src: device buffers holding row * col elements each.
// row / col: dimensions of the source matrix.
void trans_w16_gemm_cuda(half* dst, const half* src, int64_t row,
                         int64_t col) {
  // Nothing to transpose. Returning here also avoids launching a grid with
  // zero blocks, which is an invalid launch configuration.
  if (row <= 0 || col <= 0) return;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int64_t num_kernels = row * col;
  constexpr int block_size = 256;
  // Ceil-div so the tail elements get a (partially used) block; the cast makes
  // the narrowing to the launch-dimension type explicit.
  const unsigned int num_blocks =
      static_cast<unsigned int>((num_kernels + block_size - 1) / block_size);
  trans_w16_gemm_cudakernel<<<num_blocks, block_size, 0, stream>>>(
      num_kernels, dst, src, row, col);
}
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) { const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x; int n = blockIdx.x * THREADS_X + threadIdx.x;
...@@ -1847,6 +1867,17 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -1847,6 +1867,17 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c; return c;
} }
// Torch-facing entry point: writes the transpose of `src` into `dst`.
//
// `row` and `col` are the row and column counts of the original (source)
// matrix. Both tensors are handed to the kernel as raw 16-bit element buffers,
// so the arguments are validated here before the pointers are reinterpreted.
//
// dst: output tensor, overwritten with the transposed data (row * col elems).
// src: input tensor holding the row-major `row` x `col` matrix.
// Raises (via TORCH_CHECK) if the tensors are not CUDA tensors on the same
// device, do not have 2-byte elements, are not contiguous, or do not contain
// exactly row * col elements.
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row,
                    int64_t col) {
  TORCH_CHECK(row >= 0 && col >= 0,
              "trans_w16_gemm: row and col must be non-negative, got ", row,
              " and ", col);
  TORCH_CHECK(src.is_cuda() && dst.is_cuda(),
              "trans_w16_gemm: dst and src must be CUDA tensors");
  TORCH_CHECK(src.device() == dst.device(),
              "trans_w16_gemm: dst and src must be on the same device");
  // The kernel only moves 16-bit words, so any 2-byte dtype is acceptable.
  TORCH_CHECK(src.element_size() == 2 && dst.element_size() == 2,
              "trans_w16_gemm: dst and src must have 16-bit elements");
  // Flat indexing in the kernel assumes densely packed storage.
  TORCH_CHECK(src.is_contiguous() && dst.is_contiguous(),
              "trans_w16_gemm: dst and src must be contiguous");
  TORCH_CHECK(src.numel() == row * col && dst.numel() == row * col,
              "trans_w16_gemm: dst and src must each hold row * col = ",
              row * col, " elements, got ", dst.numel(), " and ",
              src.numel());

  // Empty matrix: nothing to copy, so skip the launch entirely.
  if (row == 0 || col == 0) return;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
  vllm::gptq::trans_w16_gemm_cuda(
      (half*)dst.data_ptr(),
      (const half*)src.data_ptr(),
      row,
      col);
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight( vllm::gptq::shuffle_exllama_weight(
......
...@@ -159,6 +159,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -159,6 +159,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row,int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM. // Quantized GEMM for SqueezeLLM.
ops.def( ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
...@@ -164,6 +164,10 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, ...@@ -164,6 +164,10 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(
    dst: torch.Tensor,
    src: torch.Tensor,
    row: int,
    col: int,
) -> None:
    """Forward to the ``_C.trans_w16_gemm`` custom op.

    Args:
        dst: Tensor passed as the op's ``dst`` argument (the output buffer).
        src: Tensor passed as the op's ``src`` argument (the input).
        row: Forwarded unchanged as the op's ``row`` argument.
        col: Forwarded unchanged as the op's ``col`` argument.
    """
    torch.ops._C.trans_w16_gemm(dst, src, row, col)
# squeezellm # squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
...@@ -14,6 +14,7 @@ from vllm.logger import init_logger ...@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,34 +43,6 @@ def adjust_bitsandbytes_shard(param: Parameter, ...@@ -42,34 +43,6 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset return quantized_size, quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return ``weight`` with ``num_pad`` zeros appended along one dimension.

    Args:
        weight: 1-D or 2-D tensor to extend.
        num_pad: Number of zero entries (1-D) or zero rows/columns (2-D) to
            append.
        pad_dim: For 2-D input, the dimension to grow (0 appends rows,
            1 appends columns). Ignored for 1-D input.

    Returns:
        A new tensor with the zeros concatenated after the original data. The
        zero block is created with the dtype and device of ``weight``.

    Raises:
        ValueError: If ``weight`` is not 1-D or 2-D, or if ``pad_dim`` is not
            0 or 1 for a 2-D ``weight``.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    # Work out the shape of the zero block and the axis to concatenate along.
    if rank == 1:
        zero_shape, cat_dim = (num_pad,), 0
    elif pad_dim == 0:
        zero_shape, cat_dim = (num_pad, weight.shape[1]), 0
    elif pad_dim == 1:
        zero_shape, cat_dim = (weight.shape[0], num_pad), 1
    else:
        raise ValueError("pad_dim must be 0 or 1")

    zeros = torch.zeros(*zero_shape, dtype=weight.dtype, device=weight.device)
    return torch.cat([weight, zeros], dim=cat_dim)
def gemm_bank_conf(weight):
    """Tell whether ``weight`` is both a multiple of 2048 and a power of two.

    Despite the parameter name, the value is used as an integer size: it is
    only combined with ``%``, ``&`` and comparisons against zero.

    Returns:
        True when ``weight`` is a non-zero power of two that is also divisible
        by 2048; False otherwise.
    """
    remainder = weight % 2048
    # A power of two has a single set bit, so clearing the lowest set bit
    # leaves zero.
    without_lowest_bit = weight & (weight - 1)
    return bool(remainder == 0 and without_lowest_bit == 0 and weight != 0)
class LinearMethodBase(QuantizeMethodBase): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
...@@ -115,6 +88,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -115,6 +88,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False): def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -134,20 +108,24 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -134,20 +108,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight weight = layer.weight
#print("**************matmul weight.shape:",weight.shape)
#print("self.use_llama_nn:",self.use_llama_nn)
if self.separate_bias_add: if self.separate_bias_add:
#print("********self.separate_bias_add")
if bias is not None: if bias is not None:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
if self.use_llama_nn: if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1) # print("**************matmul input.shape:",x.shape)
# print("**************matmul weight.shape:",weight.shape)
if bias is not None: if bias is not None:
return torch.matmul(x, weight) + bias return torch.matmul(x, weight) +bias
else: else:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1': return torch.matmul(x, weight)
return torch.matmul(x, weight[:,:-32])
else:
return torch.matmul(x, weight)
else: else:
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
...@@ -308,7 +286,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -308,7 +286,6 @@ class ColumnParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales. # Special case for Fp8 scales.
...@@ -330,9 +307,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -330,9 +307,6 @@ class ColumnParallelLinear(LinearBase):
shard_id=0) shard_id=0)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -397,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -397,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -477,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -477,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin. # Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False) use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes: if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim] shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if self.use_llama_nn: param_data = param_data.narrow(output_dim, shard_offset,
param_data_ = param_data.narrow(output_dim, shard_offset, shard_size)
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -527,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -527,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -597,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -597,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj self.num_kv_heads * self.head_size * tp_size, # v_proj
] ]
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=output_size, output_size=output_size,
bias=bias, bias=bias,
...@@ -604,8 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -604,8 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -713,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -713,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear):
} }
shard_size, shard_offset = adjust_bitsandbytes_shard( shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if self.use_llama_nn: param_data = param_data.narrow(output_dim, shard_offset,
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = tp_rank
else: else:
...@@ -752,25 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -752,25 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
if len(param_data.shape) == 0: if len(param_data.shape) == 0:
param_data = param_data.reshape(1) param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if self.use_llama_nn: assert param_data.shape == loaded_weight.shape
assert param_data_.shape == loaded_weight.shape param_data.copy_(loaded_weight)
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
if self.use_fa_pad and param_data.shape[0]== 12288:
param_data = pad_weight(param.data, 32)
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
if self.use_fa_pad and param_data.shape[0]== 12288 and loaded_shard_id == "v" and len(param_data.shape) == 1:
param.data = pad_weight(param.data, 32)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase): class RowParallelLinear(LinearBase):
...@@ -839,8 +781,6 @@ class RowParallelLinear(LinearBase): ...@@ -839,8 +781,6 @@ class RowParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales. # Special case for Fp8 scales.
...@@ -866,20 +806,7 @@ class RowParallelLinear(LinearBase): ...@@ -866,20 +806,7 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
if self.use_llama_nn:
if not self.use_gemm_pad:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight)
else:
param_data.copy_(loaded_weight)
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
else:
param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce. # Set up backprop all-reduce.
......
...@@ -27,6 +27,7 @@ import torch ...@@ -27,6 +27,7 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -50,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -50,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from vllm import _custom_ops as ops
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -363,6 +365,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -363,6 +365,7 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -438,8 +441,37 @@ class LlamaForCausalLM(nn.Module): ...@@ -438,8 +441,37 @@ class LlamaForCausalLM(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
# if layername=="model.layers.0.self_attn.qkv_proj.weight":
# print("weight.data[0:5][0:5]:",weight.data[0:5][0:5])
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state # make sure to leave KV cache scale factors in a known good (dummy) state
......
...@@ -10,6 +10,9 @@ import torch ...@@ -10,6 +10,9 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -29,7 +32,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -29,7 +32,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
def __init__( def __init__(
...@@ -199,6 +202,7 @@ class QWenModel(nn.Module): ...@@ -199,6 +202,7 @@ class QWenModel(nn.Module):
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -237,6 +241,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -237,6 +241,7 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -292,3 +297,31 @@ class QWenLMHeadModel(nn.Module): ...@@ -292,3 +297,31 @@ class QWenLMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
...@@ -28,6 +28,7 @@ import torch ...@@ -28,6 +28,7 @@ import torch
from torch import nn from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -48,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -48,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm import _custom_ops as ops
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
def __init__( def __init__(
...@@ -322,6 +323,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -322,6 +323,7 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -382,3 +384,32 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -382,3 +384,32 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
#print("key:\n",key)
matches = re.findall(combined_words, layername)
if matches:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight,weight.data,_weight.shape[0],_weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment