Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
835bd9fc
Commit
835bd9fc
authored
Jul 20, 2024
by
gaoqiong
Browse files
修改nn支持方式
parent
7fe40ced
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
165 additions
and
102 deletions
+165
-102
csrc/ops.h
csrc/ops.h
+1
-0
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+31
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+25
-98
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+34
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+34
-1
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+32
-1
No files found.
csrc/ops.h
View file @
835bd9fc
...
...
@@ -119,6 +119,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
// Writes the transpose of `src` into `dst`; `row` and `col` are the row and
// column counts of `src`. Implemented in csrc/quantization/gptq/q_gemm.cu.
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row,
                    int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
...
...
csrc/quantization/gptq/q_gemm.cu
View file @
835bd9fc
...
...
@@ -1542,6 +1542,26 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
}
}
// Element-wise matrix transpose, one thread per element.
//
// `src` is indexed as a row-major (row x col) matrix and `dst` as a row-major
// (col x row) matrix, so that dst[i][j] = src[j][i].
//
//   num_kernels  total number of elements to process; threads whose global
//                id is >= num_kernels exit immediately
//   dst          output buffer
//   src          input buffer
//   row, col     row and column counts of `src`
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels, T* dst,
                                          const T* src, int64_t row,
                                          int64_t col) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps for very large element counts
  // before it is stored into the 64-bit id.
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (id >= num_kernels) return;
  int64_t j = id % row;  // column index in dst
  int64_t i = id / row;  // row index in dst
  dst[i * row + j] = src[j * col + i];
}
// Host-side launcher for trans_w16_gemm_cudakernel on the current CUDA stream.
//
//   dst       output buffer, written as a (col x row) matrix by the kernel
//   src       input buffer, read as a (row x col) matrix by the kernel
//   row, col  row and column counts of `src`
//
// Launches one thread per element in blocks of 256 threads.
void trans_w16_gemm_cuda(half* dst, const half* src, int64_t row,
                         int64_t col) {
  const int64_t num_kernels = row * col;
  // Nothing to transpose; a launch with a grid size of 0 is an invalid
  // configuration, so return instead of raising a CUDA error.
  if (num_kernels <= 0) return;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int block_size = 256;
  // Ceiling division so that a partially filled last block is still launched.
  const int64_t num_blocks = (num_kernels + block_size - 1) / block_size;
  trans_w16_gemm_cudakernel<<<num_blocks, block_size, 0, stream>>>(
      num_kernels, dst, src, row, col);
}
__global__
void
shuffle_4bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
...
...
@@ -1847,6 +1867,17 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return
c
;
}
// Torch-facing entry point: transposes `src` into `dst` on src's CUDA device.
//
//   dst       output tensor, overwritten with the transposed data
//   src       input tensor
//   row, col  row and column counts of the original (src) matrix
//
// Both tensors are reinterpreted as raw 16-bit elements (`half*`), so they
// must be contiguous, hold 2-byte elements and contain exactly row * col
// elements; otherwise the kernel would read or write out of bounds or produce
// a scrambled result.
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row,
                    int64_t col) {
  TORCH_CHECK(row >= 0 && col >= 0,
              "trans_w16_gemm: row and col must be non-negative");
  TORCH_CHECK(src.element_size() == 2 && dst.element_size() == 2,
              "trans_w16_gemm: src and dst must hold 16-bit elements");
  TORCH_CHECK(src.numel() == row * col && dst.numel() == row * col,
              "trans_w16_gemm: src and dst must each contain row * col "
              "elements");
  TORCH_CHECK(src.is_contiguous() && dst.is_contiguous(),
              "trans_w16_gemm: src and dst must be contiguous");

  // Run on the device that owns `src`.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
  vllm::gptq::trans_w16_gemm_cuda((half*)dst.data_ptr(),
                                  (const half*)src.data_ptr(), row, col);
}
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
vllm
::
gptq
::
shuffle_exllama_weight
(
...
...
csrc/torch_bindings.cpp
View file @
835bd9fc
...
...
@@ -159,6 +159,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row,int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
...
...
vllm/_custom_ops.py
View file @
835bd9fc
...
...
@@ -164,6 +164,10 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit
:
int
)
->
None
:
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, row: int,
                   col: int) -> None:
    """Forward to the ``trans_w16_gemm`` custom op under ``torch.ops._C``.

    The op is registered with ``dst`` marked as mutated (``Tensor! dst``), so
    the result is written into ``dst`` in place and nothing is returned.

    Args:
        dst: Output tensor that receives the transposed data.
        src: Input tensor to transpose.
        row: Number of rows of ``src``.
        col: Number of columns of ``src``.
    """
    torch.ops._C.trans_w16_gemm(dst, src, row, col)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/linear.py
View file @
835bd9fc
...
...
@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
import
os
logger
=
init_logger
(
__name__
)
...
...
@@ -42,34 +43,6 @@ def adjust_bitsandbytes_shard(param: Parameter,
return
quantized_size
,
quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return ``weight`` with ``num_pad`` zeros appended at the end.

    A 1-D tensor is always extended along its only dimension (``pad_dim`` is
    ignored). A 2-D tensor gets ``num_pad`` extra zero rows when ``pad_dim``
    is 0, or ``num_pad`` extra zero columns when ``pad_dim`` is 1. The padding
    uses the dtype and device of ``weight``; the input is not modified.

    Raises:
        ValueError: if ``weight`` is not 1-D or 2-D, or if ``weight`` is 2-D
            and ``pad_dim`` is neither 0 nor 1.
    """
    rank = weight.dim()
    if rank == 1:
        zeros_shape = (num_pad, )
        cat_dim = 0
    elif rank == 2:
        if pad_dim == 0:
            # Extra rows: same column count as the input.
            zeros_shape = (num_pad, weight.shape[1])
            cat_dim = 0
        elif pad_dim == 1:
            # Extra columns: same row count as the input.
            zeros_shape = (weight.shape[0], num_pad)
            cat_dim = 1
        else:
            raise ValueError("pad_dim must be 0 or 1")
    else:
        raise ValueError("Weight tensor must be 1D or 2D")

    zeros = torch.zeros(zeros_shape,
                        dtype=weight.dtype,
                        device=weight.device)
    return torch.cat([weight, zeros], dim=cat_dim)
def gemm_bank_conf(weight):
    """Report whether ``weight`` is a power of two and a multiple of 2048.

    Returns ``True`` only for non-zero values that have a single bit set and
    are divisible by 2048; every other value (including 0) gives ``False``.
    """
    remainder = weight % 2048
    # x & (x - 1) clears the lowest set bit, leaving 0 only when at most one
    # bit was set; the explicit non-zero test then excludes 0 itself.
    without_lowest_bit = weight & (weight - 1)
    return bool(remainder == 0 and without_lowest_bit == 0 and weight != 0)
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
...
...
@@ -115,6 +88,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
    # Stored as given; read later by this object's other methods.
    self.separate_bias_add = separate_bias_add
    # Opt-in switch read once at construction time: enabled only when the
    # LLAMA_NN environment variable is exactly the string '1' (unset or any
    # other value leaves it False).
    self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
...
...
@@ -134,20 +108,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
#print("**************matmul weight.shape:",weight.shape)
#print("self.use_llama_nn:",self.use_llama_nn)
if
self
.
separate_bias_add
:
#print("********self.separate_bias_add")
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
if
self
.
use_llama_nn
:
weight
=
weight
.
reshape
(
weight
.
shape
[
1
],
-
1
)
# print("**************matmul input.shape:",x.shape)
# print("**************matmul weight.shape:",weight.shape)
if
bias
is
not
None
:
return
torch
.
matmul
(
x
,
weight
)
+
bias
return
torch
.
matmul
(
x
,
weight
)
+
bias
else
:
if
gemm_bank_conf
(
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
return
torch
.
matmul
(
x
,
weight
[:,:
-
32
])
else
:
return
torch
.
matmul
(
x
,
weight
)
return
torch
.
matmul
(
x
,
weight
)
else
:
return
F
.
linear
(
x
,
weight
,
bias
)
...
...
@@ -308,7 +286,6 @@ class ColumnParallelLinear(LinearBase):
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
...
...
@@ -330,9 +307,6 @@ class ColumnParallelLinear(LinearBase):
shard_id
=
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
param_data
.
shape
[
0
],
-
1
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
...
...
@@ -397,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -477,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
if
self
.
use_llama_nn
:
param_data_
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
...
...
@@ -527,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data_
.
copy_
(
loaded_weight
)
if
loaded_shard_id
==
1
and
len
(
param_data
.
shape
)
==
2
:
param_data
=
param_data
.
transpose
(
0
,
1
)
param
.
data
=
param_data
.
reshape
(
param_data
.
shape
[
1
],
-
1
)
else
:
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
...
...
@@ -597,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# k_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# v_proj
]
super
().
__init__
(
input_size
=
input_size
,
output_size
=
output_size
,
bias
=
bias
,
...
...
@@ -604,8 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -713,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear):
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
self
.
use_llama_nn
:
param_data_
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
shard_id
=
tp_rank
else
:
...
...
@@ -752,25 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data_
.
copy_
(
loaded_weight
)
if
loaded_shard_id
==
"v"
and
len
(
param_data
.
shape
)
==
2
:
if
self
.
use_fa_pad
and
param_data
.
shape
[
0
]
==
12288
:
param_data
=
pad_weight
(
param
.
data
,
32
)
param_data
=
param_data
.
transpose
(
0
,
1
)
param
.
data
=
param_data
.
reshape
(
param_data
.
shape
[
1
],
-
1
)
if
self
.
use_fa_pad
and
param_data
.
shape
[
0
]
==
12288
and
loaded_shard_id
==
"v"
and
len
(
param_data
.
shape
)
==
1
:
param
.
data
=
pad_weight
(
param
.
data
,
32
)
else
:
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
LinearBase
):
...
...
@@ -839,8 +781,6 @@ class RowParallelLinear(LinearBase):
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
...
...
@@ -866,20 +806,7 @@ class RowParallelLinear(LinearBase):
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
if
not
self
.
use_gemm_pad
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
param_data
.
shape
[
0
],
-
1
)
param_data
.
copy_
(
loaded_weight
)
else
:
param_data
.
copy_
(
loaded_weight
)
if
gemm_bank_conf
(
param
.
data
.
shape
[
0
])
and
self
.
use_gemm_pad
:
param
.
data
=
pad_weight
(
param
.
data
,
32
)
param
.
data
=
param
.
data
.
transpose
(
0
,
1
)
param
.
data
=
param
.
data
.
reshape
(
param
.
data
.
shape
[
1
],
-
1
)
else
:
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
...
...
vllm/model_executor/models/llama.py
View file @
835bd9fc
...
...
@@ -27,6 +27,7 @@ import torch
from
torch
import
nn
from
transformers
import
LlamaConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -50,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm
import
_custom_ops
as
ops
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -363,6 +365,7 @@ class LlamaForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
...
...
@@ -438,8 +441,37 @@ class LlamaForCausalLM(nn.Module):
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
#print("key:\n",key)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
# if layername=="model.layers.0.self_attn.qkv_proj.weight":
# print("weight.data[0:5][0:5]:",weight.data[0:5][0:5])
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
...
...
vllm/model_executor/models/qwen.py
View file @
835bd9fc
...
...
@@ -10,6 +10,9 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -29,7 +32,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
class
QWenMLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -199,6 +202,7 @@ class QWenModel(nn.Module):
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
...
...
@@ -237,6 +241,7 @@ class QWenLMHeadModel(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
...
...
@@ -292,3 +297,31 @@ class QWenLMHeadModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
#print("key:\n",key)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
vllm/model_executor/models/qwen2.py
View file @
835bd9fc
...
...
@@ -28,6 +28,7 @@ import torch
from
torch
import
nn
from
transformers
import
Qwen2Config
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -48,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
class
Qwen2MLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -322,6 +323,7 @@ class Qwen2ForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
...
...
@@ -382,3 +384,32 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
:
#以上代码模型权重已经加载完了
#以下代码使用正则匹配来找出要修改的weight
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
#合并所有关键词为一个正则表达式
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
#print("key:\n",key)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
#print(layername)
# print(weight.data)
#创建一个跟value一样大的tensor
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment