Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b2d25ef
Unverified
Commit
6b2d25ef
authored
Nov 19, 2024
by
Yan Ma
Committed by
GitHub
Nov 18, 2024
Browse files
[Hardware][XPU] AWQ/GPTQ support for xpu backend (#10107)
Signed-off-by:
yan ma
<
yan.ma@intel.com
>
parent
281cc4b3
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
146 additions
and
52 deletions
+146
-52
docs/source/quantization/supported_hardware.rst
docs/source/quantization/supported_hardware.rst
+4
-4
tests/quantization/test_ipex_quant.py
tests/quantization/test_ipex_quant.py
+6
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+1
-1
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+0
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+4
-0
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+128
-41
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+3
-1
No files found.
docs/source/quantization/supported_hardware.rst
View file @
6b2d25ef
...
@@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations
...
@@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✗
- ✗
-
✗
-
✅︎
- ✅︎
- ✅︎
- ✗
- ✗
- ✗
- ✗
...
@@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations
...
@@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✗
- ✗
-
✗
-
✅︎
-
✗
-
✅︎
- ✗
- ✗
- ✗
- ✗
* - Marlin (GPTQ/AWQ/FP8)
* - Marlin (GPTQ/AWQ/FP8)
...
...
tests/quantization/test_ipex_quant.py
View file @
6b2d25ef
"""Test model set-up and inference for quantized HF models supported
"""Test model set-up and inference for quantized HF models supported
on the CPU backend using IPEX (including AWQ).
on the CPU
/GPU
backend using IPEX (including AWQ
/GPTQ
).
Validating the configuration and printing results for manual checking.
Validating the configuration and printing results for manual checking.
...
@@ -11,13 +11,15 @@ import pytest
...
@@ -11,13 +11,15 @@ import pytest
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MODELS
=
[
MODELS
=
[
"casperhansen/llama-3-8b-instruct-awq"
,
"AMead10/Llama-3.2-1B-Instruct-AWQ"
,
"shuyuej/Llama-3.2-1B-Instruct-GPTQ"
,
# with g_idx
]
]
DTYPE
=
[
"bfloat16"
]
DTYPE
=
[
"bfloat16"
]
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
()
reason
=
"only supports the CPU backend."
)
and
not
current_platform
.
is_xpu
(),
reason
=
"only supports Intel CPU/XPU backend."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
def
test_ipex_quant
(
vllm_runner
,
model
,
dtype
):
def
test_ipex_quant
(
vllm_runner
,
model
,
dtype
):
...
...
vllm/model_executor/layers/linear.py
View file @
6b2d25ef
...
@@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
]
]
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
6b2d25ef
...
@@ -210,7 +210,6 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -210,7 +210,6 @@ class GPTQLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# for torch.compile
# for torch.compile
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
6b2d25ef
...
@@ -23,6 +23,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
...
@@ -23,6 +23,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
PackedColumnParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
RowvLLMParameter
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -134,6 +135,9 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -134,6 +135,9 @@ class GPTQMarlinConfig(QuantizationConfig):
sym
=
quant_config
.
get
(
"sym"
)
sym
=
quant_config
.
get
(
"sym"
)
desc_act
=
quant_config
.
get
(
"desc_act"
)
desc_act
=
quant_config
.
get
(
"desc_act"
)
if
not
current_platform
.
is_cuda
():
return
False
if
quant_method
!=
"gptq"
:
if
quant_method
!=
"gptq"
:
return
False
return
False
...
...
vllm/model_executor/layers/quantization/ipex_quant.py
View file @
6b2d25ef
...
@@ -2,21 +2,26 @@ from typing import Any, Dict, List, Optional
...
@@ -2,21 +2,26 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.quantization.awq
import
AWQLinearMethod
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.awq
import
(
AWQLinearMethod
,
is_layer_skipped_awq
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQLinearMethod
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MIN_IPEX_VERSION
=
"2.5.0"
class
IPEXConfig
(
QuantizationConfig
):
class
IPEXConfig
(
QuantizationConfig
):
"""INT8 quantization config class using IPEX for the CPU backend,
"""INT8 quantization config class using IPEX for the CPU
/XPU
backend,
including AWQ.
including AWQ
, GPTQ
.
"""
"""
IPEX_QUANT_METHOD_MAP
=
{
IPEX_QUANT_METHOD_MAP
=
{
"awq"
:
1
,
"awq"
:
1
,
"gptq"
:
2
,
"gptq"
:
0
,
}
}
def
__init__
(
def
__init__
(
...
@@ -24,29 +29,30 @@ class IPEXConfig(QuantizationConfig):
...
@@ -24,29 +29,30 @@ class IPEXConfig(QuantizationConfig):
method
:
str
,
method
:
str
,
weight_bits
:
int
,
weight_bits
:
int
,
group_size
:
int
,
group_size
:
int
,
modules_to_not_convert
:
Optional
[
List
[
str
]]
=
None
,
desc_act
:
Optional
[
bool
]
=
None
,
lm_head_quantized
:
Optional
[
bool
]
=
None
,
)
->
None
:
)
->
None
:
self
.
method
=
method
self
.
method
=
method
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
self
.
desc_act
=
desc_act
self
.
lm_head_quantized
=
lm_head_quantized
self
.
pack_factor
=
32
//
self
.
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
if
self
.
weight_bits
not
in
[
4
]:
if
self
.
weight_bits
not
in
[
4
]:
raise
ValueError
(
f
"IPEX quantization supports weight bits [4], "
raise
ValueError
(
f
"IPEX quantization supports weight bits [4], "
f
"but got
{
self
.
weight_bits
}
."
)
f
"but got
{
self
.
weight_bits
}
."
)
if
self
.
method
==
"awq"
:
if
self
.
method
not
in
[
"awq"
,
"gptq"
]:
self
.
quant_method
=
IPEXAWQLinearMethod
raise
ValueError
(
f
"IPEX quantization supports [awq, gptq], "
else
:
raise
ValueError
(
f
"IPEX quantization supports [awq], "
f
"but got
{
self
.
method
}
."
)
f
"but got
{
self
.
method
}
."
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"IPEXConfig(method=
{
self
.
method
}
"
return
(
f
"IPEXConfig(method=
{
self
.
method
}
,
"
f
"weight_bits=
{
self
.
weight_bits
}
, "
f
"weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
"
)
f
"group_size=
{
self
.
group_size
}
)"
)
def
get_ipex_quant_method_id
(
self
)
->
int
:
return
IPEXConfig
.
IPEX_QUANT_METHOD_MAP
[
self
.
method
]
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -70,19 +76,32 @@ class IPEXConfig(QuantizationConfig):
...
@@ -70,19 +76,32 @@ class IPEXConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"IPEXConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"IPEXConfig"
:
method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
]).
lower
()
method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
]).
lower
()
if
method
==
"awq"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"w_bit"
,
"bits"
])
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"w_bit"
,
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"q_group_size"
,
"group_size"
])
group_size
=
cls
.
get_from_keys
(
config
,
return
cls
(
method
,
weight_bits
,
group_size
)
[
"q_group_size"
,
"group_size"
])
modules_to_not_convert
=
cls
.
get_from_keys_or
(
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
method
,
weight_bits
,
group_size
,
modules_to_not_convert
,
False
,
False
)
# otherwise for gptq
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
desc_act
=
cls
.
get_from_keys_or
(
config
,
[
"desc_act"
],
default
=
False
)
return
cls
(
method
,
weight_bits
,
group_size
,
[],
desc_act
,
lm_head_quantized
)
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
user_quant
)
->
Optional
[
str
]:
if
not
current_platform
.
is_cpu
():
if
not
current_platform
.
is_cpu
()
and
not
current_platform
.
is_xpu
()
:
return
None
return
None
quant_method
=
hf_quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
quant_method
=
hf_quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
if
quant_method
in
[
"awq"
]:
if
quant_method
in
[
"awq"
,
"gptq"
]:
return
cls
.
get_name
()
return
cls
.
get_name
()
return
None
return
None
...
@@ -90,12 +109,81 @@ class IPEXConfig(QuantizationConfig):
...
@@ -90,12 +109,81 @@ class IPEXConfig(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"LinearMethodBase"
]:
prefix
:
str
)
->
Optional
[
"LinearMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
self
.
quant_method
(
self
)
if
self
.
method
==
"awq"
:
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
IPEXAWQLinearMethod
(
self
)
if
self
.
method
==
"gptq"
:
return
IPEXGPTQLinearMethod
(
self
)
return
None
return
None
class
IPEXGPTQLinearMethod
(
GPTQLinearMethod
):
"""GPTQ linear method using IPEX for the CPU/XPU backend.
"""
def
__init__
(
self
,
quant_config
:
IPEXConfig
):
self
.
quant_config
=
quant_config
# type: ignore
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
bias
=
layer
.
bias
if
not
layer
.
skip_bias_add
else
None
try
:
import
intel_extension_for_pytorch
as
ipex
if
ipex
.
__version__
<
MIN_IPEX_VERSION
:
raise
ImportError
(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f
"intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
."
)
except
ImportError
as
err
:
raise
ImportError
(
"Please install "
f
"intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
via "
f
"`pip install intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
`"
" to use IPEX-AWQ linear method."
)
from
err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode
=
ipex
.
quantization
.
WoqLowpMode
.
INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype
=
ipex
.
quantization
.
WoqWeightDtype
.
INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode
=
ipex
.
quantization
.
WoqActQuantMode
.
PER_BATCH_IC_BLOCK
qconfig
=
ipex
.
quantization
.
get_weight_only_quant_qconfig_mapping
(
weight_dtype
=
weight_dtype
,
lowp_mode
=
lowp_mode
,
act_quant_mode
=
act_quant_mode
,
group_size
=
self
.
quant_config
.
group_size
,
)
layer
.
ipex_output_size
=
layer
.
qweight
.
shape
[
-
1
]
g_idx
=
layer
.
g_idx
if
self
.
quant_config
.
desc_act
else
None
layer
.
ipex_qlinear
=
ipex
.
llm
.
quantization
.
woq_linear
.
\
IPEXWeightOnlyQuantizedLinear
.
from_weight
(
layer
.
qweight
,
layer
.
scales
,
layer
.
qzeros
,
layer
.
qweight
.
size
(
0
),
layer
.
ipex_output_size
,
qconfig
=
qconfig
,
g_idx
=
g_idx
,
bias
=
bias
,
group_size
=
self
.
quant_config
.
group_size
,
quant_method
=
IPEXConfig
.
IPEX_QUANT_METHOD_MAP
[
"gptq"
]
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out
=
layer
.
ipex_qlinear
(
reshaped_x
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
x
.
shape
[:
-
1
]
+
(
layer
.
ipex_output_size
,
))
class
IPEXAWQLinearMethod
(
AWQLinearMethod
):
class
IPEXAWQLinearMethod
(
AWQLinearMethod
):
"""AWQ linear method using IPEX for the CPU backend.
"""AWQ linear method using IPEX for the CPU
/XPU
backend.
"""
"""
def
__init__
(
self
,
quant_config
:
IPEXConfig
):
def
__init__
(
self
,
quant_config
:
IPEXConfig
):
...
@@ -108,15 +196,16 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
...
@@ -108,15 +196,16 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
try
:
try
:
import
intel_extension_for_pytorch
as
ipex
import
intel_extension_for_pytorch
as
ipex
if
ipex
.
__version__
<
"2.4.0"
:
if
ipex
.
__version__
<
MIN_IPEX_VERSION
:
raise
ImportError
(
"intel_extension_for_pytorch version is "
raise
ImportError
(
"intel_extension_for_pytorch version is "
"wrong. Please install "
"wrong. Please install "
"intel_extension_for_pytorch>=
2.4.0
."
)
f
"intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
."
)
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
raise
ImportError
(
"Please install "
"Please install "
"intel_extension_for_pytorch>=
2.4.0
via "
f
"intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
via "
"`pip install intel_extension_for_pytorch>=
2.4.0
`"
f
"`pip install intel_extension_for_pytorch>=
{
MIN_IPEX_VERSION
}
`"
" to use IPEX-AWQ linear method."
)
from
err
" to use IPEX-AWQ linear method."
)
from
err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
...
@@ -136,8 +225,8 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
...
@@ -136,8 +225,8 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
layer
.
ipex_output_size
=
layer
.
qweight
.
size
(
layer
.
ipex_output_size
=
layer
.
qweight
.
size
(
1
)
*
self
.
quant_config
.
pack_factor
1
)
*
self
.
quant_config
.
pack_factor
layer
.
ipex_qlinear
=
ipex
.
nn
.
modules
.
weight_only_
quantization
.
\
layer
.
ipex_qlinear
=
ipex
.
llm
.
quantization
.
woq_linear
.
\
WeightOnlyQuantizedLinear
.
from_weight
(
IPEX
WeightOnlyQuantizedLinear
.
from_weight
(
layer
.
qweight
,
layer
.
qweight
,
layer
.
scales
,
layer
.
scales
,
layer
.
qzeros
,
layer
.
qzeros
,
...
@@ -146,8 +235,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
...
@@ -146,8 +235,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
qconfig
=
qconfig
,
qconfig
=
qconfig
,
bias
=
bias
,
bias
=
bias
,
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
quant_method
=
quant_method
=
IPEXConfig
.
IPEX_QUANT_METHOD_MAP
[
"awq"
]
# type: ignore
self
.
quant_config
.
get_ipex_quant_method_id
()
# type: ignore
)
)
def
apply
(
self
,
def
apply
(
self
,
...
@@ -156,5 +244,4 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
...
@@ -156,5 +244,4 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out
=
layer
.
ipex_qlinear
(
reshaped_x
)
out
=
layer
.
ipex_qlinear
(
reshaped_x
)
return
out
.
reshape
(
x
.
shape
[:
-
1
]
+
(
layer
.
ipex_output_size
,
))
return
out
.
reshape
(
x
.
shape
[:
-
1
]
+
(
layer
.
ipex_output_size
,
))
vllm/model_executor/model_loader/loader.py
View file @
6b2d25ef
...
@@ -29,6 +29,8 @@ from vllm.envs import VLLM_USE_MODELSCOPE
...
@@ -29,6 +29,8 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ReplicatedLinear
,
from
vllm.model_executor.layers.linear
import
(
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
)
from
vllm.model_executor.model_loader.tensorizer
import
(
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
serialize_vllm_model
,
tensorizer_weights_iterator
)
serialize_vllm_model
,
tensorizer_weights_iterator
)
...
@@ -348,7 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -348,7 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
isinstance
(
quant_method
,
QuantizeMethodBase
)
:
# When quant methods need to process weights after loading
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# to be on the global target device. This scope is for the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment