Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10a02535
Unverified
Commit
10a02535
authored
Aug 09, 2025
by
Eldar Kurtić
Committed by
GitHub
Aug 08, 2025
Browse files
Fix loading of quantized BigCode models (#22463)
Signed-off-by:
Eldar Kurtic
<
eldar@neuralmagic.com
>
parent
65552b47
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+14
-4
No files found.
vllm/model_executor/models/gpt_bigcode.py
View file @
10a02535
...
@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
GPTBigCodeAttention
(
nn
.
Module
):
class
GPTBigCodeAttention
(
nn
.
Module
):
...
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
...
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
total_num_kv_heads
,
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_attn"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
...
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
...
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
...
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
...
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_fc"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
act
=
get_act_fn
(
config
.
activation_function
)
self
.
act
=
get_act_fn
(
config
.
activation_function
)
...
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
...
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
quant_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
quant_config
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
...
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if
"c_attn.input_scale"
in
name
or
"c_attn.weight_scale"
in
name
:
if
"c_attn.input_scale"
in
name
:
weight_loader
(
param
,
loaded_weight
,
'q'
)
weight_loader
(
param
,
loaded_weight
,
'q'
)
weight_loader
(
param
,
loaded_weight
,
'k'
)
weight_loader
(
param
,
loaded_weight
,
'k'
)
weight_loader
(
param
,
loaded_weight
,
'v'
)
weight_loader
(
param
,
loaded_weight
,
'v'
)
...
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
vllm_config
=
vllm_config
,
self
.
transformer
=
GPTBigCodeModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
self
.
lm_head
=
self
.
transformer
.
wte
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment