Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7193774b
Unverified
Commit
7193774b
authored
Sep 25, 2024
by
Michael Goin
Committed by
GitHub
Sep 25, 2024
Browse files
[Misc] Support quantization of MllamaForCausalLM (#8822)
parent
e2c6e0a8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+9
-2
No files found.
vllm/model_executor/models/mllama.py
View file @
7193774b
...
@@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
self
,
self
,
config
:
Optional
[
config_mllama
.
MllamaTextConfig
]
=
None
,
config
:
Optional
[
config_mllama
.
MllamaTextConfig
]
=
None
,
layer_idx
:
Optional
[
int
]
=
None
,
layer_idx
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
self
.
num_heads
,
self
.
num_heads
,
self
.
num_key_value_heads
,
self
.
num_key_value_heads
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
input_is_parallel
=
True
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
)
)
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead
# use huggingface's instead
...
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
...
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
and feedforward."""
def
__init__
(
self
,
config
:
config_mllama
.
MllamaTextConfig
,
layer_idx
:
int
)
\
def
__init__
(
self
,
config
:
config_mllama
.
MllamaTextConfig
,
layer_idx
:
int
,
quant_config
:
Optional
[
QuantizationConfig
])
\
->
None
:
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
cross_attn
=
MllamaTextCrossAttention
(
self
.
cross_attn
=
MllamaTextCrossAttention
(
config
=
config
,
config
=
config
,
layer_idx
=
layer_idx
,
layer_idx
=
layer_idx
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
...
@@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
)
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
...
@@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
for
layer_idx
in
range
(
config
.
num_hidden_layers
):
for
layer_idx
in
range
(
config
.
num_hidden_layers
):
if
layer_idx
in
self
.
cross_attention_layers
:
if
layer_idx
in
self
.
cross_attention_layers
:
layers
.
append
(
layers
.
append
(
MllamaCrossAttentionDecoderLayer
(
config
,
layer_idx
))
MllamaCrossAttentionDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
))
else
:
else
:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers
.
append
(
layers
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment