sglang · Commit eddaa2b5 (Unverified)
Authored Mar 29, 2024 by Qubitium; committed by GitHub Mar 28, 2024
Add support for new autogptq quant_config.checkpoint_format (#332)
Parent: 2af565b3
Showing 1 changed file with 28 additions and 23 deletions.
python/sglang/srt/managers/router/model_runner.py (+28 -23)
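Background for this change, not part of the commit itself: autogptq changed how a marlin-serialized GPTQ checkpoint is flagged in the model's quantization_config. Older releases (<=0.7.1) used a boolean is_marlin_format field, while newer releases (>=0.8.0) record a checkpoint_format string. A minimal sketch of the two shapes the patched code accepts; the dicts and their "bits" entries are illustrative, not taken from a real checkpoint:

# autogptq <=0.7.1: boolean flag
old_style_cfg = {"quant_method": "gptq", "bits": 4, "is_marlin_format": True}

# autogptq >=0.8.0: string field
new_style_cfg = {"quant_method": "gptq", "bits": 4, "checkpoint_format": "marlin"}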
python/sglang/srt/managers/router/model_runner.py @ eddaa2b5
@@ -19,7 +19,11 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 logger = logging.getLogger("model_runner")
@@ -300,30 +304,31 @@ class ModelRunner:
         # Load weights
         linear_method = None
-        with _set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (
-                        hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]
-                    ):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+            linear_method = quant_config.get_linear_method()
+        with _set_default_torch_dtype(torch.float16):
+            with torch.device("cuda"):
                 model = model_class(
                     config=self.model_config.hf_config, linear_method=linear_method
                 )
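To illustrate what the rewritten block does, here is a standalone sketch of the new detection logic with no vllm or torch dependencies. resolve_quant_method and the sample dicts are hypothetical names invented for this example, not part of the sglang codebase:

def resolve_quant_method(quant_cfg: dict) -> str:
    # Normalize the declared method ("gptq", "awq", "marlin", ...).
    quant_method = quant_cfg.get("quant_method", "").lower()
    # Accept both autogptq conventions:
    #   >=0.8.0: checkpoint_format: str
    #   <=0.7.1: is_marlin_format: bool
    is_format_marlin = (
        quant_cfg.get("checkpoint_format") == "marlin"
        or quant_cfg.get("is_marlin_format", False)
    )
    # Use marlin if the GPTQ model is serialized in marlin format.
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method

# Both the old and the new config shape resolve to the marlin path.
assert resolve_quant_method({"quant_method": "gptq", "is_marlin_format": True}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}) == "marlin"
# A plain GPTQ checkpoint keeps its declared method.
assert resolve_quant_method({"quant_method": "gptq"}) == "gptq"

The commit then maps the resolved name through QUANTIZATION_CONFIG_MAPPING to a vllm quantization config class and builds the linear method from it, raising ValueError for any unmapped name.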