Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d087bf86
Unverified
Commit
d087bf86
authored
Oct 31, 2024
by
Michael Goin
Committed by
GitHub
Oct 30, 2024
Browse files
[Model] Support quantization of Qwen2VisionTransformer (#9817)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
890ca360
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
23 deletions
+35
-23
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+35
-23
No files found.
vllm/model_executor/models/qwen2_vl.py
View file @
d087bf86
...
@@ -126,15 +126,18 @@ class Qwen2VisionMLP(nn.Module):
...
@@ -126,15 +126,18 @@ class Qwen2VisionMLP(nn.Module):
hidden_features
:
int
=
None
,
hidden_features
:
int
=
None
,
act_layer
:
Type
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
Type
[
nn
.
Module
]
=
QuickGELU
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
fc1
=
ColumnParallelLinear
(
in_features
,
self
.
fc1
=
ColumnParallelLinear
(
in_features
,
hidden_features
,
hidden_features
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
)
self
.
act
=
act_layer
()
self
.
act
=
act_layer
()
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
in_features
,
in_features
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x_parallel
,
_
=
self
.
fc1
(
x
)
x_parallel
,
_
=
self
.
fc1
(
x
)
...
@@ -196,6 +199,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -196,6 +199,7 @@ class Qwen2VisionAttention(nn.Module):
num_heads
:
Optional
[
int
]
=
None
,
num_heads
:
Optional
[
int
]
=
None
,
projection_size
:
Optional
[
int
]
=
None
,
projection_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Per attention head and per partition values.
# Per attention head and per partition values.
...
@@ -207,10 +211,12 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -207,10 +211,12 @@ class Qwen2VisionAttention(nn.Module):
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
output_size
=
3
*
projection_size
,
output_size
=
3
*
projection_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
)
self
.
proj
=
RowParallelLinear
(
input_size
=
projection_size
,
self
.
proj
=
RowParallelLinear
(
input_size
=
projection_size
,
output_size
=
embed_dim
,
output_size
=
embed_dim
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
...
@@ -310,6 +316,7 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -310,6 +316,7 @@ class Qwen2VisionBlock(nn.Module):
act_layer
:
Type
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
Type
[
nn
.
Module
]
=
QuickGELU
,
norm_layer
:
Type
[
nn
.
Module
]
=
None
,
norm_layer
:
Type
[
nn
.
Module
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
norm_layer
is
None
:
if
norm_layer
is
None
:
...
@@ -321,11 +328,13 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -321,11 +328,13 @@ class Qwen2VisionBlock(nn.Module):
self
.
attn
=
Qwen2VisionAttention
(
embed_dim
=
dim
,
self
.
attn
=
Qwen2VisionAttention
(
embed_dim
=
dim
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
projection_size
=
dim
,
projection_size
=
dim
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
mlp
=
Qwen2VisionMLP
(
dim
,
self
.
mlp
=
Qwen2VisionMLP
(
dim
,
mlp_hidden_dim
,
mlp_hidden_dim
,
act_layer
=
act_layer
,
act_layer
=
act_layer
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
def
forward
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rotary_pos_emb
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -374,6 +383,7 @@ class Qwen2VisionPatchMerger(nn.Module):
...
@@ -374,6 +383,7 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer
:
Type
[
nn
.
Module
]
=
None
,
norm_layer
:
Type
[
nn
.
Module
]
=
None
,
spatial_merge_size
:
int
=
2
,
spatial_merge_size
:
int
=
2
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
...
@@ -384,12 +394,14 @@ class Qwen2VisionPatchMerger(nn.Module):
...
@@ -384,12 +394,14 @@ class Qwen2VisionPatchMerger(nn.Module):
ColumnParallelLinear
(
self
.
hidden_size
,
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.0"
),
nn
.
GELU
(),
nn
.
GELU
(),
RowParallelLinear
(
self
.
hidden_size
,
RowParallelLinear
(
self
.
hidden_size
,
d_model
,
d_model
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.2"
),
])
])
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -440,6 +452,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -440,6 +452,7 @@ class Qwen2VisionTransformer(nn.Module):
vision_config
:
Qwen2VLVisionConfig
,
vision_config
:
Qwen2VLVisionConfig
,
norm_eps
:
float
=
1e-6
,
norm_eps
:
float
=
1e-6
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -467,28 +480,29 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -467,28 +480,29 @@ class Qwen2VisionTransformer(nn.Module):
self
.
rotary_pos_emb
=
Qwen2VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
Qwen2VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
blocks
=
nn
.
ModuleList
([
self
.
blocks
=
nn
.
ModuleList
([
Qwen2VisionBlock
(
Qwen2VisionBlock
(
dim
=
embed_dim
,
dim
=
embed_dim
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
mlp_ratio
=
mlp_ratio
,
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
for
_
in
range
(
depth
)
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
depth
)
])
])
self
.
merger
=
Qwen2VisionPatchMerger
(
self
.
merger
=
Qwen2VisionPatchMerger
(
d_model
=
hidden_size
,
d_model
=
hidden_size
,
context_dim
=
embed_dim
,
context_dim
=
embed_dim
,
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
)
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
blocks
[
0
].
mlp
.
fc2
.
weight
.
dtype
return
self
.
patch_embed
.
proj
.
weight
.
dtype
@
property
@
property
def
device
(
self
)
->
torch
.
device
:
def
device
(
self
)
->
torch
.
device
:
return
self
.
blocks
[
0
].
mlp
.
fc2
.
weight
.
device
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
rot_pos_emb
(
self
,
grid_thw
:
torch
.
Tensor
)
->
torch
.
Tensor
:
pos_ids
=
[]
pos_ids
=
[]
...
@@ -932,10 +946,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -932,10 +946,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
visual
=
Qwen2VisionTransformer
(
self
.
visual
=
Qwen2VisionTransformer
(
config
.
vision_config
,
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
# NOTE: Qwen2-VL vision encoder does not support any
prefix
=
"visual"
,
# quantization method now.
quant_config
=
None
,
)
)
self
.
model
=
Qwen2Model
(
config
,
self
.
model
=
Qwen2Model
(
config
,
...
@@ -1175,7 +1187,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1175,7 +1187,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
if
"visual"
in
name
and
"qkv.weight"
in
name
:
if
"visual"
in
name
and
name
.
endswith
(
"qkv.weight"
)
:
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
head_size
=
visual_embed_dim
//
visual_num_heads
...
@@ -1184,7 +1196,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1184,7 +1196,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
visual_embed_dim
)
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
elif
"visual"
in
name
and
"qkv.bias"
in
name
:
elif
"visual"
in
name
and
name
.
endswith
(
"qkv.bias"
)
:
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
head_size
=
visual_embed_dim
//
visual_num_heads
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment