Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a53046b1
Unverified
Commit
a53046b1
authored
Nov 05, 2024
by
Michael Goin
Committed by
GitHub
Nov 05, 2024
Browse files
[Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
731aec5b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
40 deletions
+90
-40
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+30
-0
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+60
-40
No files found.
vllm/model_executor/layers/activation.py
View file @
a53046b1
...
@@ -299,3 +299,33 @@ def get_act_fn(
...
@@ -299,3 +299,33 @@ def get_act_fn(
return
ScaledActivation
(
act_fn
,
intermediate_size
,
input_is_parallel
,
return
ScaledActivation
(
act_fn
,
intermediate_size
,
input_is_parallel
,
params_dtype
)
params_dtype
)
return
act_fn
return
act_fn
_ACTIVATION_AND_MUL_REGISTRY
=
LazyDict
({
"gelu"
:
lambda
:
GeluAndMul
(),
"silu"
:
lambda
:
SiluAndMul
(),
})
def
get_act_and_mul_fn
(
act_fn_name
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
intermediate_size
:
Optional
[
int
]
=
None
,
input_is_parallel
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
nn
.
Module
:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name
=
act_fn_name
.
lower
()
if
act_fn_name
not
in
_ACTIVATION_AND_MUL_REGISTRY
:
raise
ValueError
(
f
"Activation function
{
act_fn_name
!
r
}
is not supported."
)
act_fn
=
_ACTIVATION_AND_MUL_REGISTRY
[
act_fn_name
]
if
(
quant_config
is
not
None
and
act_fn_name
in
quant_config
.
get_scaled_act_names
()):
if
intermediate_size
is
None
:
raise
ValueError
(
"intermediate_size must be specified for scaled "
"activation functions."
)
return
ScaledActivation
(
act_fn
,
intermediate_size
,
input_is_parallel
,
params_dtype
)
return
act_fn
vllm/model_executor/models/pixtral.py
View file @
a53046b1
...
@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata
...
@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
MultiModalConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_
and_mul_
fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module):
...
@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module):
super
().
__init__
()
super
().
__init__
()
assert
config
.
intermediate_size
is
not
None
assert
config
.
intermediate_size
is
not
None
# TODO: Use quant_config and prefix after optimizing this
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_proj
=
nn
.
Linear
(
config
.
hidden_size
,
input_size
=
config
.
hidden_size
,
config
.
intermediate_size
,
output_sizes
=
[
config
.
intermediate_size
]
*
2
,
bias
=
False
)
bias
=
False
,
self
.
up_proj
=
nn
.
Linear
(
config
.
hidden_size
,
quant_config
=
quant_config
,
config
.
intermediate_size
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
bias
=
False
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
config
.
intermediate_size
,
self
.
down_proj
=
nn
.
Linear
(
config
.
intermediate_size
,
output_size
=
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
bias
=
False
)
quant_config
=
quant_config
,
self
.
act
=
get_act_fn
(
config
.
hidden_act
)
prefix
=
f
"
{
prefix
}
.down_proj"
)
self
.
act_and_mul
=
get_act_and_mul_fn
(
config
.
hidden_act
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
down_proj
(
self
.
act
(
self
.
gate_proj
(
x
))
*
self
.
up_proj
(
x
))
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_and_mul
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
PixtralHFAttention
(
nn
.
Module
):
class
PixtralHFAttention
(
nn
.
Module
):
...
@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module):
...
@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module):
self
.
n_heads
=
config
.
num_attention_heads
self
.
n_heads
=
config
.
num_attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
config
.
hidden_size
,
# TODO: Use quant_config and prefix after optimizing this
head_size
=
self
.
head_dim
,
self
.
q_proj
=
nn
.
Linear
(
config
.
hidden_size
,
total_num_heads
=
self
.
n_heads
,
config
.
hidden_siz
e
,
bias
=
Fals
e
,
bias
=
False
)
quant_config
=
quant_config
,
self
.
k_proj
=
nn
.
Linear
(
config
.
hidden_size
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
config
.
hidden_size
,
)
bias
=
False
)
self
.
o_proj
=
RowParallelLinear
(
self
.
v_proj
=
nn
.
Linear
(
config
.
hidden_size
,
input_size
=
config
.
hidden_size
,
config
.
hidden_size
,
output_size
=
config
.
hidden_size
,
bias
=
False
)
bias
=
False
,
self
.
o_proj
=
nn
.
Linear
(
config
.
hidden_size
,
quant_config
=
quant_config
,
config
.
hidden_size
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
bias
=
False
)
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module):
...
@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module):
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
batch
,
patches
,
_
=
hidden_states
.
size
()
batch
,
patches
,
_
=
hidden_states
.
size
()
q
=
self
.
q_proj
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
k
=
self
.
k_proj
(
hidden_states
)
q
,
k
,
v
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
v
=
self
.
v_proj
(
hidden_states
)
# Transpose q and k to apply HF's Rotary Position Embedding
# Transpose q and k to apply HF's Rotary Position Embedding
q
=
q
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
q
=
q
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
v
=
v
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
)
cos
,
sin
=
position_embeddings
cos
,
sin
=
position_embeddings
q
,
k
=
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
unsqueeze_dim
=
0
)
q
,
k
=
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
unsqueeze_dim
=
0
)
...
@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module):
...
@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module):
# Transpose q and k back for attention
# Transpose q and k back for attention
q
=
q
.
transpose
(
1
,
2
).
contiguous
()
q
=
q
.
transpose
(
1
,
2
).
contiguous
()
k
=
k
.
transpose
(
1
,
2
).
contiguous
()
k
=
k
.
transpose
(
1
,
2
).
contiguous
()
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
)
out
=
xops
.
memory_efficient_attention
(
q
,
out
=
xops
.
memory_efficient_attention
(
q
,
k
,
k
,
v
,
v
,
attn_bias
=
attention_mask
)
attn_bias
=
attention_mask
)
else
:
else
:
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
v
=
v
.
transpose
(
1
,
2
)
self
.
head_dim
).
transpose
(
1
,
2
)
out
=
nn
.
functional
.
scaled_dot_product_attention
(
out
=
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attention_mask
)
q
,
k
,
v
,
attn_mask
=
attention_mask
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
reshape
(
batch
,
patches
,
self
.
n_heads
*
self
.
head_dim
)
out
=
out
.
view
(
batch
,
patches
,
self
.
n_heads
*
self
.
head_dim
)
attn_output
,
_
=
self
.
o_proj
(
out
)
return
self
.
o_proj
(
out
)
return
attn_output
,
None
class
PixtralHFTransformerBlock
(
nn
.
Module
):
class
PixtralHFTransformerBlock
(
nn
.
Module
):
...
@@ -912,7 +918,7 @@ class PixtralHFTransformerBlock(nn.Module):
...
@@ -912,7 +918,7 @@ class PixtralHFTransformerBlock(nn.Module):
attention_mask
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
=
self
.
attention
.
forward
(
self
.
attention_norm
(
hidden_states
),
r
,
_
=
self
.
attention
.
forward
(
self
.
attention_norm
(
hidden_states
),
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_embeddings
=
position_embeddings
)
position_embeddings
=
position_embeddings
)
h
=
hidden_states
+
r
h
=
hidden_states
+
r
...
@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[]
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
layer_count
=
len
(
self
.
transformer
.
layers
)
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# omit layers when num_hidden_layers_override is set
if
name
.
startswith
(
"transformer.layers"
):
layer_idx
=
int
(
name
.
split
(
"."
)[
2
])
if
layer_idx
>=
layer_count
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment