Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e58014d7
Commit
e58014d7
authored
Jun 28, 2024
by
zhuwenwen
Browse files
add gemm padding
parent
e2df3544
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
4 deletions
+46
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+44
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
e58014d7
...
...
@@ -44,6 +44,34 @@ def adjust_bitsandbytes_shard(param: Parameter,
return
quantized_size
,
quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return ``weight`` with ``num_pad`` zeros appended along one dimension.

    A 1D tensor is always extended along its only dimension; ``pad_dim`` is
    not consulted in that case. A 2D tensor is extended along ``pad_dim``:
    ``0`` appends rows and ``1`` appends columns. The appended zeros are
    created with the dtype and device of ``weight``.

    Args:
        weight: The 1D or 2D tensor to extend.
        num_pad: How many zero entries (1D), rows or columns (2D) to append.
        pad_dim: Dimension of a 2D tensor to extend; must be 0 or 1.

    Returns:
        A new tensor consisting of ``weight`` followed by the zero block.

    Raises:
        ValueError: If ``weight`` is neither 1D nor 2D, or if a 2D tensor is
            given a ``pad_dim`` other than 0 or 1.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    # Work out the shape of the zero block and the axis to concatenate on;
    # the allocation and the concatenation are then shared by every case.
    if rank == 1:
        zeros_shape = (num_pad,)
        axis = 0
    elif pad_dim == 0:
        zeros_shape = (num_pad, weight.shape[1])
        axis = 0
    elif pad_dim == 1:
        zeros_shape = (weight.shape[0], num_pad)
        axis = 1
    else:
        raise ValueError("pad_dim must be 0 or 1")

    zeros = torch.zeros(*zeros_shape,
                        dtype=weight.dtype,
                        device=weight.device)
    return torch.cat([weight, zeros], dim=axis)
def gemm_bank_conf(weight):
    """Tell whether a size is a non-zero power of two divisible by 2048.

    Despite the parameter name, the call sites visible in this file pass an
    integer shape entry (e.g. ``weight.shape[1] - 32``), not a tensor.

    Args:
        weight: The integer size to test.

    Returns:
        True if ``weight`` is a multiple of 2048 and a non-zero power of two,
        False otherwise.
    """
    divisible_by_2048 = weight % 2048 == 0
    # A power of two has exactly one bit set, so clearing its lowest set bit
    # with ``x & (x - 1)`` leaves zero; zero itself is excluded explicitly.
    single_bit_set = (weight & (weight - 1)) == 0 and weight != 0
    return bool(divisible_by_2048 and single_bit_set)
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
...
...
@@ -118,7 +146,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
if
bias
is
not
None
:
return
torch
.
matmul
(
x
,
weight
)
+
bias
else
:
return
torch
.
matmul
(
x
,
weight
)
if
gemm_bank_conf
(
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
return
torch
.
matmul
(
x
,
weight
[:,:
-
32
])
else
:
return
torch
.
matmul
(
x
,
weight
)
else
:
return
F
.
linear
(
x
,
weight
,
bias
)
...
...
@@ -806,6 +837,7 @@ class RowParallelLinear(LinearBase):
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
...
...
@@ -831,10 +863,18 @@ class RowParallelLinear(LinearBase):
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
param_data
.
shape
[
0
],
-
1
)
# if self.use_llama_nn:
# loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data
.
copy_
(
loaded_weight
)
if
self
.
use_llama_nn
:
if
gemm_bank_conf
(
param
.
data
.
shape
[
0
])
and
self
.
use_gemm_pad
:
param
.
data
=
pad_weight
(
param
.
data
,
32
)
param
.
data
=
param
.
data
.
transpose
(
0
,
1
)
param
.
data
=
param
.
data
.
reshape
(
param
.
data
.
shape
[
1
],
-
1
)
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
...
...
vllm/model_executor/model_loader/utils.py
View file @
e58014d7
...
...
@@ -25,6 +25,8 @@ def get_model_architecture(
if
architectures
==
[
'LlamaForCausalLM'
]
or
architectures
==
[
'ChatGLMModel'
]
or
architectures
==
[
'BaichuanForCausalLM'
]:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'0'
:
os
.
environ
[
'GEMM_PAD'
]
=
'1'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment