Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
680 additions
and
178 deletions
+680
-178
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+199
-9
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+15
-8
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+8
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+53
-5
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+2
-100
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+25
-11
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+6
-6
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+166
-0
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
...el_executor/layers/quantization/kernels/MPLinearKernel.py
+4
-0
vllm/model_executor/layers/quantization/kernels/__init__.py
vllm/model_executor/layers/quantization/kernels/__init__.py
+5
-3
vllm/model_executor/layers/quantization/kernels/exllama.py
vllm/model_executor/layers/quantization/kernels/exllama.py
+140
-0
vllm/model_executor/layers/quantization/kernels/machete.py
vllm/model_executor/layers/quantization/kernels/machete.py
+7
-7
vllm/model_executor/layers/quantization/kernels/marlin.py
vllm/model_executor/layers/quantization/kernels/marlin.py
+5
-4
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+15
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+6
-6
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+19
-3
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
6d2051cc
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -35,12 +41,13 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -35,12 +41,13 @@ class AWQMarlinConfig(QuantizationConfig):
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
has_zp
=
has_zp
self
.
has_zp
=
has_zp
self
.
lm_head_quantized
=
lm_head_quantized
self
.
lm_head_quantized
=
lm_head_quantized
self
.
weight_bits
=
weight_bits
if
weight_bits
not
in
self
.
TYPE_MAP
:
if
self
.
weight_bits
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
f
"Unsupported num_bits =
{
weight_bits
}
. "
raise
ValueError
(
f
"Unsupported num_bits =
{
self
.
weight_bits
}
. "
f
"Supported num_bits =
{
self
.
TYPE_MAP
.
keys
()
}
"
)
f
"Supported num_bits =
{
self
.
TYPE_MAP
.
keys
()
}
"
)
self
.
quant_type
=
self
.
TYPE_MAP
[
weight_bits
]
self
.
quant_type
=
self
.
TYPE_MAP
[
self
.
weight_bits
]
verify_marlin_supported
(
self
.
quant_type
,
verify_marlin_supported
(
self
.
quant_type
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
...
@@ -98,10 +105,12 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -98,10 +105,12 @@ class AWQMarlinConfig(QuantizationConfig):
return
None
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"
AWQMarlinLinear
Method"
]:
prefix
:
str
)
->
Optional
[
"
Quantize
Method
Base
"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
AWQMarlinLinearMethod
(
self
)
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
AWQMoEMethod
(
self
)
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
@@ -115,6 +124,9 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -115,6 +124,9 @@ class AWQMarlinConfig(QuantizationConfig):
group_size
=
quant_config
.
get
(
"group_size"
)
group_size
=
quant_config
.
get
(
"group_size"
)
has_zp
=
quant_config
.
get
(
"zero_point"
)
has_zp
=
quant_config
.
get
(
"zero_point"
)
if
not
current_platform
.
is_cuda
():
return
False
if
quant_method
!=
"awq"
:
if
quant_method
!=
"awq"
:
return
False
return
False
...
@@ -271,4 +283,182 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -271,4 +283,182 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type
=
self
.
quant_config
.
quant_type
,
quant_type
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
bias
=
bias
)
bias
=
bias
)
\ No newline at end of file
class AWQMoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for AWQ checkpoints using Marlin kernels.

    ``create_weights`` registers the packed weights, scales and zero points
    on the layer, ``process_weights_after_loading`` converts them with the
    Marlin repack/permute helpers, and ``apply`` routes tokens to experts
    and calls ``fused_marlin_moe``.
    """

    def __init__(self, quant_config: AWQMarlinConfig):
        # The config supplies pack_factor, group_size and weight_bits, which
        # are read by every method below.
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the (uninitialized) quantized MoE parameters.

        Registers ``w13_qweight``, ``w2_qweight``, ``w13_scales``,
        ``w2_scales``, ``w13_qzeros`` and ``w2_qzeros`` on ``layer``; every
        parameter gets ``extra_weight_attrs`` attached through
        ``set_weight_attrs``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size: Expert intermediate size; the ``w13``
                tensors hold twice this width (w1 and w3 together).
            params_dtype: dtype used for the scale tensors.
            **extra_weight_attrs: Attributes attached to every parameter.
                Updated in place with ``is_transposed`` and ``quant_method``.
        """
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        # Packed int32 weights: the last dimension is divided by pack_factor
        # (presumably the number of quantized values per int32 -- confirm
        # against AWQMarlinConfig).
        w13_qweight = Parameter(torch.empty(num_experts,
                                            hidden_size,
                                            2 * intermediate_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
                                requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = Parameter(torch.empty(num_experts,
                                           intermediate_size,
                                           hidden_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        # One scale / zero-point row per quantization group along the input
        # dimension of each projection.
        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size // self.quant_config.group_size

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           intermediate_size * 2,
                                           dtype=params_dtype),
                               requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size,
                                          dtype=params_dtype),
                              requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           2 * intermediate_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size //
                                          self.quant_config.pack_factor,
                                          dtype=torch.int32),
                              requires_grad=False)
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AWQ tensors in place for the Marlin MoE kernel.

        Creates empty ``(num_experts, 0)`` sort-index parameters, then
        replaces the qweight, scale and zero-point parameters on ``layer``
        with the outputs of the corresponding repack/permute helpers.
        """
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        # Zero-width placeholders; they are passed to the repack op below.
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        # size_n is recovered from the packed last dimension by multiplying
        # back by pack_factor.
        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # Why does this take the intermediate size for size_k?
        # NOTE(review): both the w13 and w2 scale permutations pass
        # intermediate_size_per_partition as size_k -- confirm this is what
        # marlin_moe_permute_scales expects for w13.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused Marlin MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting top-k weights and ids are forwarded to
        ``fused_marlin_moe`` together with the layer's quantized tensors.

        Returns:
            The tensor returned by ``fused_marlin_moe``.
        """
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        )
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
6d2051cc
...
@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[
"gelu"
,
"gelu_fast"
,
"gelu_new"
,
"gelu_pytorch_tanh"
]
return
[]
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
...
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
...
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
if
generation
==
0
or
generation
==
1
:
if
generation
==
0
or
generation
==
1
:
matmul_states
[
i
]
=
MatmulLtState
()
matmul_states
[
i
]
=
MatmulLtState
()
matmul_states
[
i
].
CB
=
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
matmul_states
[
i
].
CB
=
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
matmul_states
[
i
].
SCB
=
quant_states
[
i
]
matmul_states
[
i
].
SCB
=
quant_states
[
i
]
.
to
(
x
.
device
)
matmul_states
[
i
].
threshold
=
(
matmul_states
[
i
].
threshold
=
(
self
.
quant_config
.
llm_int8_threshold
)
self
.
quant_config
.
llm_int8_threshold
)
matmul_states
[
i
].
has_fp16_weights
=
(
matmul_states
[
i
].
has_fp16_weights
=
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
6d2051cc
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
from
typing
import
Any
,
Dict
,
List
,
Optional
,
cast
import
torch
import
torch
from
compressed_tensors.config
import
CompressionFormat
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
@@ -16,8 +20,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -16,8 +20,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
find_matched_target
,
is_activation_quantization_format
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
should_ignore_layer
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -138,10 +141,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -138,10 +141,11 @@ class CompressedTensorsConfig(QuantizationConfig):
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
is_tensor
=
(
weight_strategy
and
input_quant
.
strategy
is_tensor
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
==
QuantizationStrategy
.
TENSOR
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_static
=
not
weight_quant
.
dynamic
and
not
input_quant
.
dynamic
is_static
=
not
weight_quant
.
dynamic
and
not
input_quant
.
dynamic
return
is_8_bits
and
is_tensor
and
is_symmetric
and
is_static
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_tensor
and
weight_quant
.
symmetric
and
is_static
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
...
@@ -151,10 +155,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -151,10 +155,11 @@ class CompressedTensorsConfig(QuantizationConfig):
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
is_token
=
(
weight_strategy
and
input_quant
.
strategy
is_token
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
.
value
)
==
QuantizationStrategy
.
TOKEN
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
return
is_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return
is_8_bits
and
is_token
and
weight_quant
.
symmetric
and
is_dynamic
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
...
@@ -265,12 +270,14 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -265,12 +270,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
True
)
is_static_input_scheme
=
True
,
input_symmetric
=
input_quant
.
symmetric
)
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
False
)
is_static_input_scheme
=
False
,
input_symmetric
=
input_quant
.
symmetric
)
raise
NotImplementedError
(
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
"No compressed-tensors compatible scheme was found."
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
6d2051cc
...
@@ -3,14 +3,14 @@ from enum import Enum
...
@@ -3,14 +3,14 @@ from enum import Enum
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
compressed_tensors
import
CompressionFormat
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x
,
x
,
layer
.
w13_weight_packed
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
layer
.
w2_weight_packed
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
g_idx1
=
layer
.
w13_g_idx
,
w2_scale
=
layer
.
w2_weight_scale
,
g_idx2
=
layer
.
w2_g_idx
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
num_bits
,
num_bits
=
self
.
num_bits
,
)
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
View file @
6d2051cc
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
6d2051cc
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
apply_fp8_linear
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
requantize_with_max_scale
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
6d2051cc
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_int8_linear
,
convert_to_channelwise
)
apply_int8_linear
,
convert_to_channelwise
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
...
@@ -14,12 +14,16 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
...
@@ -14,12 +14,16 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
logger
=
init_logger
(
__name__
)
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
input_symmetric
=
input_symmetric
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -46,10 +50,43 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -46,10 +50,43 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
requires_grad
=
False
)
requires_grad
=
False
)
# INPUT SCALE
# INPUT SCALE
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
if
self
.
input_symmetric
:
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_zero_point
=
None
else
:
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
azps
=
layer
.
input_zero_point
.
to
(
dtype
=
torch
.
int32
)
range_max
=
(
layer
.
input_scale
*
(
int8_traits
.
max
-
azps
)).
max
()
range_min
=
(
layer
.
input_scale
*
(
int8_traits
.
min
-
azps
)).
min
()
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
layer
.
input_scale
=
Parameter
(
scale
,
requires_grad
=
False
)
# AZP loaded as int8 but used as int32
azp
=
(
int8_traits
.
min
-
range_min
/
scale
).
to
(
dtype
=
torch
.
int32
)
layer
.
input_zero_point
=
Parameter
(
azp
,
requires_grad
=
False
)
else
:
else
:
layer
.
input_scale
=
None
layer
.
input_scale
=
None
layer
.
input_zero_point
=
None
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if
not
self
.
input_symmetric
:
layer
.
azp_adj
=
layer
.
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
else
:
layer
.
azp_adj
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
...
@@ -90,6 +127,15 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -90,6 +127,15 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
self
.
input_symmetric
:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
@@ -97,4 +143,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -97,4 +143,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
input_zero_point
=
layer
.
input_zero_point
,
azp_adj
=
layer
.
azp_adj
,
bias
=
bias
)
bias
=
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
6d2051cc
from
typing
import
Callable
,
List
,
Optional
,
Set
from
typing
import
Callable
,
List
,
Optional
,
Set
import
torch
import
torch
from
compressed_tensors.quantization
import
ActivationOrdering
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.kernels
import
(
from
vllm.model_executor.layers.quantization.kernels
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
6d2051cc
import
re
import
re
from
enum
import
Enum
from
typing
import
Iterable
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
,
field_validator
from
compressed_tensors
import
CompressionFormat
from
torch.nn
import
Module
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
FUSED_LAYER_NAME_MAPPING
)
FUSED_LAYER_NAME_MAPPING
)
class CompressionFormat(Enum):
    """Closed set of compression format identifiers.

    Each member's value is the hyphenated string form of its name.
    """

    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    naive_quantized = "naive-quantized"
    float_quantized = "float-quantized"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"
class QuantizationType(str, Enum):
    """
    Enum storing quantization type options
    """

    # Members are also ``str`` instances, so they compare equal to their
    # plain string values.
    INT = "int"
    FLOAT = "float"
class QuantizationStrategy(str, Enum):
    """
    Enum storing quantization strategy options
    """

    # Members are also ``str`` instances, so they compare equal to their
    # plain string values.
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"
    TOKEN = "token"
class ActivationOrdering(str, Enum):
    """
    Enum storing strategies for activation ordering

    Group: reorder groups and weight

    Weight: only reorder weight, not groups. Slightly lower latency and
    accuracy compared to group actorder
    """

    GROUP = "group"
    WEIGHT = "weight"
class QuantizationArgs(BaseModel):
    """
    User facing arguments used to define a quantization config
    for weights or activations

    :param num_bits: quantization bit depth
    :param type: dtype to quantized to, either int or float
    :param symmetric: whether or not quantization scale is symmetric
    :param strategy: string determining the scope of scale/zero-point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block
        strategy, must be of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization -
        values will not be calibrated during calibration phase,
        instead during inference new quantization ranges will be
        observed with every sample. Defaults to False for static
        quantization. Note that enabling dynamic quantization
        will change the default observer to a memoryless one
    :param actorder: whether to apply group quantization in decreasing order of
        activation. Defaults to None for arbitrary ordering
    """

    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    actorder: Union[ActivationOrdering, bool, None] = None
    observer: str = Field(
        default="minmax",
        description=("The class to use to compute the quantization param - "
                     "scale and zero-point'"),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=
        ("optional dict of kwargs to be passed directly to torch quantization "
         "Observers constructor excluding quantization range or symmetry"),
    )

    @field_validator("actorder", mode="before")
    @classmethod
    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
        """Normalize the raw ``actorder`` input before field validation.

        ``True`` maps to ``ActivationOrdering.GROUP`` and ``False`` to
        ``None``; strings are lower-cased and converted to
        ``ActivationOrdering`` (an unknown string raises ``ValueError`` from
        the enum lookup); any other value is returned unchanged.
        """
        # bool must be tested first so True/False never reach the str branch.
        if isinstance(value, bool):
            return ActivationOrdering.GROUP if value else None

        if isinstance(value, str):
            return ActivationOrdering(value.lower())

        return value
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
_ACTIVATION_QUANTIZATION_FORMATS
=
[
_ACTIVATION_QUANTIZATION_FORMATS
=
[
CompressionFormat
.
naive_quantized
.
value
,
CompressionFormat
.
naive_quantized
.
value
,
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
6d2051cc
...
@@ -86,15 +86,16 @@ class GGUFLinearMethod(LinearMethodBase):
...
@@ -86,15 +86,16 @@ class GGUFLinearMethod(LinearMethodBase):
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
tensor_shape
=
(
output_size_per_partition
,
input_size_per_partition
)
tensor_shape
=
(
output_size_per_partition
,
input_size_per_partition
)
qweight
=
UninitializedParameter
(
requires_grad
=
False
)
qweight
=
GGUF
UninitializedParameter
(
requires_grad
=
False
)
set_weight_attrs
(
set_weight_attrs
(
qweight
,
{
qweight
,
{
"input_dim"
:
1
,
"input_dim"
:
1
,
"output_dim"
:
0
,
"output_dim"
:
0
,
"tensor_shape"
:
tensor_shape
,
"tensor_shape"
:
tensor_shape
,
"is_gguf_weight"
:
True
,
"is_gguf_weight"
:
True
,
"
shard_size
"
:
{}
,
"
data_container
"
:
[]
,
"shard_id"
:
[],
"shard_id"
:
[],
"shard_id_map"
:
{},
})
})
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
...
@@ -116,21 +117,17 @@ class GGUFLinearMethod(LinearMethodBase):
...
@@ -116,21 +117,17 @@ class GGUFLinearMethod(LinearMethodBase):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
shard_size
=
getattr
(
layer
.
qweight
,
"shard_size"
,
None
)
shard_id
=
getattr
(
layer
.
qweight
,
"shard_id"
,
None
)
shard_id
=
getattr
(
layer
.
qweight
,
"shard_id"
,
None
)
if
shard_id
and
shard_size
:
if
shard_id
:
result
=
[]
offset
=
0
# dequantize shard weights respectively
# dequantize shard weights respectively
shard_id
=
[
"q"
,
"k"
,
"v"
]
if
"q"
in
shard_id
else
shard_id
shard_id
=
[
"q"
,
"k"
,
"v"
]
if
"q"
in
shard_id
else
shard_id
qweight
=
layer
.
qweight
.
unbind
(
0
)
result
=
[]
for
id
in
shard_id
:
for
id
in
shard_id
:
shard_weight
=
layer
.
qweight
[
q_idx
=
layer
.
qweight
.
shard_id_map
[
id
]
offset
:
offset
+
shard_size
[
id
][
0
],
:
shard_size
[
id
][
1
]].
contiguous
()
qweight_type
=
layer
.
qweight_type
.
shard_weight_type
[
id
]
qweight_type
=
layer
.
qweight_type
.
shard_weight_type
[
id
]
result
.
append
(
_fuse_mul_mat
(
x
,
shard_weight
,
qweight_type
))
result
.
append
(
_fuse_mul_mat
(
x
,
qweight
[
q_idx
],
qweight_type
))
offset
+=
shard_size
[
id
][
0
]
out
=
torch
.
cat
(
result
,
axis
=
1
)
out
=
torch
.
cat
(
result
,
axis
=
1
)
else
:
else
:
qweight
=
layer
.
qweight
qweight
=
layer
.
qweight
...
@@ -162,3 +159,20 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
...
@@ -162,3 +159,20 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
dequant
=
ops
.
ggml_dequantize
(
quant
,
qweight_type
,
hidden_size
,
dequant
=
ops
.
ggml_dequantize
(
quant
,
qweight_type
,
hidden_size
,
x_flat
.
shape
[
0
])
x_flat
.
shape
[
0
])
return
dequant
.
view
(
*
x
.
shape
,
hidden_size
)
return
dequant
.
view
(
*
x
.
shape
,
hidden_size
)
class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter that can be materialized as a nested tensor.

    The tensors held in ``data_container`` are combined into a single nested
    tensor by :meth:`materialize_nested`.
    """

    cls_to_become = Parameter
    data_container: List[torch.Tensor]

    def materialize_nested(self) -> Parameter:
        """Return a ``cls_to_become`` instance wrapping the nested data.

        A uint8 nested tensor is built from ``data_container`` on this
        parameter's device, the container is then emptied, and every
        instance attribute of ``self`` is copied onto the new parameter.
        """
        pieces = self.data_container
        nested = torch.nested.nested_tensor(pieces,
                                            device=self.device,
                                            dtype=torch.uint8)
        # The nested tensor has been created from the pieces; empty the
        # container before the attributes are copied below.
        pieces.clear()

        materialized = torch.Tensor._make_subclass(self.cls_to_become,
                                                   nested,
                                                   require_grad=False)
        for attr_name, attr_value in vars(self).items():
            setattr(materialized, attr_name, attr_value)
        return materialized
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
6d2051cc
...
@@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x
,
x
,
layer
.
w13_qweight
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
layer
.
w2_qweight
,
layer
.
w13_scales
,
layer
.
w2_scales
,
router_logits
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
w1_scale
=
layer
.
w13_scales
,
g_idx1
=
layer
.
w13_g_idx
,
w2_scale
=
layer
.
w2_scales
,
g_idx2
=
layer
.
w2_g_idx
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
).
to
(
orig_dtype
)
).
to
(
orig_dtype
)
vllm/model_executor/layers/quantization/ipex_quant.py
0 → 100644
View file @
6d2051cc
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.awq
import
AWQLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.platforms
import
current_platform
class IPEXConfig(QuantizationConfig):
    """INT8 quantization config class using IPEX for the CPU backend,
    including AWQ.
    """

    # Integer id handed to IPEX as ``quant_method`` for each checkpoint
    # format name (see ``get_ipex_quant_method_id``).
    IPEX_QUANT_METHOD_MAP = {
        "awq": 1,
        "gptq": 2,
    }

    def __init__(
        self,
        method: str,
        weight_bits: int,
        group_size: int,
    ) -> None:
        """Store the quantization settings and pick the linear method.

        :param method: quantization method name; only ``"awq"`` is accepted
        :param weight_bits: bit width of the packed weights; only 4 is
            accepted
        :param group_size: quantization group size
        :raises ValueError: if ``weight_bits`` or ``method`` is unsupported
        """
        # Validate before deriving ``pack_factor`` so that an unsupported
        # width such as 0 raises ValueError instead of ZeroDivisionError.
        if weight_bits not in [4]:
            raise ValueError(f"IPEX quantization supports weight bits [4], "
                             f"but got {weight_bits}.")

        self.method = method
        self.weight_bits = weight_bits
        self.group_size = group_size
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

        if self.method == "awq":
            self.quant_method = IPEXAWQLinearMethod
        else:
            raise ValueError(f"IPEX quantization supports [awq], "
                             f"but got {self.method}.")

    def __repr__(self) -> str:
        return (f"IPEXConfig(method={self.method}, "
                f"weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    def get_ipex_quant_method_id(self) -> int:
        """Return the integer id IPEX uses for this config's method."""
        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]

    @classmethod
    def get_name(cls) -> str:
        """Return the name under which this config is registered."""
        return "ipex"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): -1 presumably means "no minimum" since this config
        targets the CPU backend - confirm against the base class contract.
        """
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
        """Build an ``IPEXConfig`` from a parsed quantization config dict.

        The method name is read from ``quant_method`` (lower-cased), the bit
        width from ``w_bit`` or ``bits`` and the group size from
        ``q_group_size`` or ``group_size``.
        """
        method = cls.get_from_keys(config, ["quant_method"]).lower()
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        return cls(method, weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Return this config's name when it should replace the HF method.

        The override applies only on the CPU platform and only when the
        checkpoint's ``quant_method`` is ``awq``; otherwise ``None`` is
        returned. ``user_quant`` is not consulted.
        """
        if not current_platform.is_cpu():
            return None

        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        if quant_method in ["awq"]:
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
        """Return a linear method instance for ``LinearBase`` layers.

        Any other layer type yields ``None``.
        """
        if isinstance(layer, LinearBase):
            return self.quant_method(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation names reported for the configured method.

        The GELU variants are returned for ``awq``, an empty list otherwise.
        """
        if self.method == "awq":
            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
        else:
            return []
class IPEXAWQLinearMethod(AWQLinearMethod):
    """AWQ linear method using IPEX for the CPU backend.
    """

    def __init__(self, quant_config: IPEXConfig):
        self.quant_config = quant_config  # type: ignore

    @staticmethod
    def _release_tuple(version: str) -> tuple:
        """Return the leading dotted-integer part of *version* as ints.

        ``"2.10.0+cpu"`` becomes ``(2, 10, 0)``. Fewer than three components
        are zero-padded (``"2.4"`` -> ``(2, 4, 0)``). A string without any
        leading digits yields ``()``.
        """
        import re  # local import so the block does not depend on file imports

        matched = re.match(r"\d+(?:\.\d+)*", version)
        if matched is None:
            return ()
        parts = tuple(int(piece) for piece in matched.group().split("."))
        return parts + (0, ) * (3 - len(parts))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Build the IPEX weight-only-quantized linear module for ``layer``.

        After the parent class post-processing, this stores the module in
        ``layer.ipex_qlinear`` and its output width in
        ``layer.ipex_output_size``.

        :raises ImportError: if ``intel_extension_for_pytorch`` is missing
            or older than 2.4.0
        """
        super().process_weights_after_loading(layer=layer)

        # The bias is handed to the IPEX module unless the layer asks for the
        # bias add to be skipped.
        bias = layer.bias if not layer.skip_bias_add else None

        try:
            import intel_extension_for_pytorch as ipex

            # Compare numerically: a plain string comparison would treat
            # "2.10.0" as older than "2.4.0". An unparseable version string
            # is not rejected.
            installed = self._release_tuple(str(ipex.__version__))
            if installed and installed < (2, 4, 0):
                raise ImportError("intel_extension_for_pytorch version is "
                                  "wrong. Please install "
                                  "intel_extension_for_pytorch>=2.4.0.")
        except ImportError as err:
            raise ImportError(
                "Please install "
                "intel_extension_for_pytorch>=2.4.0 via "
                "`pip install intel_extension_for_pytorch>=2.4.0`"
                " to use IPEX-AWQ linear method.") from err

        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
        # with better performance.
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
        # The weight will be de-packed from INT4 to INT8.
        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
        # The float activation will be quantized (dynamic, per-token) to INT8.
        # NOTE(review): the mode selected below is PER_BATCH, which does not
        # obviously match "per-token" - confirm against the IPEX docs.
        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH

        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=weight_dtype,
            lowp_mode=lowp_mode,
            act_quant_mode=act_quant_mode,
            group_size=self.quant_config.group_size,
        )

        # qweight's second dimension is packed, so the unpacked output width
        # is that size multiplied by the pack factor.
        layer.ipex_output_size = layer.qweight.size(
            1) * self.quant_config.pack_factor
        layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
            WeightOnlyQuantizedLinear.from_weight(
            layer.qweight,
            layer.scales,
            layer.qzeros,
            layer.qweight.size(0),
            layer.ipex_output_size,
            qconfig=qconfig,
            bias=bias,
            group_size=self.quant_config.group_size,
            quant_method=
            self.quant_config.get_ipex_quant_method_id()  # type: ignore
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through the IPEX quantized linear module.

        ``x`` is flattened to 2D, passed to ``layer.ipex_qlinear`` and the
        result is reshaped to ``x.shape[:-1] + (layer.ipex_output_size,)``.
        The ``bias`` argument is not used here; the bias was given to the
        IPEX module in ``process_weights_after_loading``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)
        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
View file @
6d2051cc
...
@@ -42,6 +42,10 @@ class MPLinearKernel(ABC):
...
@@ -42,6 +42,10 @@ class MPLinearKernel(ABC):
self
.
config
=
c
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
w_q_name
=
w_q_param_name
self
.
w_s_name
=
w_s_param_name
self
.
w_s_name
=
w_s_param_name
if
c
.
zero_points
:
assert
w_zp_param_name
is
not
None
if
c
.
has_g_idx
:
assert
w_gidx_param_name
is
not
None
self
.
w_zp_name
=
w_zp_param_name
self
.
w_zp_name
=
w_zp_param_name
self
.
w_gidx_name
=
w_gidx_param_name
self
.
w_gidx_name
=
w_gidx_param_name
...
...
vllm/model_executor/layers/quantization/kernels/__init__.py
View file @
6d2051cc
import
os
from
typing
import
List
,
Optional
,
Type
from
typing
import
List
,
Optional
,
Type
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.kernels.exllama
import
(
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.machete
import
(
from
vllm.model_executor.layers.quantization.kernels.machete
import
(
MacheteLinearKernel
)
MacheteLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.marlin
import
(
from
vllm.model_executor.layers.quantization.kernels.marlin
import
(
...
@@ -13,6 +15,7 @@ from vllm.platforms import current_platform
...
@@ -13,6 +15,7 @@ from vllm.platforms import current_platform
_POSSIBLE_KERNELS
:
List
[
Type
[
MPLinearKernel
]]
=
[
_POSSIBLE_KERNELS
:
List
[
Type
[
MPLinearKernel
]]
=
[
MacheteLinearKernel
,
MacheteLinearKernel
,
MarlinLinearKernel
,
MarlinLinearKernel
,
ExllamaLinearKernel
,
]
]
...
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
...
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
failure_reasons
=
[]
failure_reasons
=
[]
for
kernel
in
_POSSIBLE_KERNELS
:
for
kernel
in
_POSSIBLE_KERNELS
:
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
)
\
if
kernel
.
__name__
in
envs
.
VLLM_DISABLED_KERNELS
:
.
split
(
","
):
failure_reasons
.
append
(
failure_reasons
.
append
(
f
'
{
kernel
.
__name__
}
disabled by environment variable'
)
f
'
{
kernel
.
__name__
}
disabled by environment variable'
)
continue
continue
...
...
vllm/model_executor/layers/quantization/kernels/exllama.py
0 → 100644
View file @
6d2051cc
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_quantized_values_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
vllm.scalar_type
import
scalar_types
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class ExllamaLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the Exllama GPTQ ops.

    Weights are shuffled with ``ops.gptq_shuffle`` after loading and the
    matmul is carried out by ``ops.gptq_gemm``.
    """

    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
    # currently untested so not added to the list

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability this kernel requires."""
        return 60

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Check whether this kernel can serve the layer config ``c``.

        :return: ``(True, None)`` when supported, otherwise ``(False,
            reason)`` with a human-readable explanation
        """
        if c.has_g_idx and\
            c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Exllama, "\
                          "when the input features are partitioned across "\
                          "devices"

        # The synthesized zero points are packed along the output dimension,
        # so that dimension has to be a whole number of packed words.
        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return False, "Output features must be a multiple of the pack " \
                            "factor (32 / num_bits) so that we can correctly " \
                            "pack the zero points"

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                           "Exllama, supported types are: "\
                           f"{cls.SUPPORTED_QUANT_TYPES}"

        if c.full_weight_shape[0] % c.group_size != 0:
            return False, f"Group size ({c.group_size}) does not evenly divide"\
                           " the number of input features "\
                           f"({c.full_weight_shape[0]})"

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Prepare ``layer``'s parameters for the Exllama kernel.

        Creates a zero-point tensor when the config has none, turns the
        group index into a permutation (or installs an empty placeholder),
        and transforms the quantized weight and scale parameters.

        :raises NotImplementedError: if zero points must be synthesized for
            a weight type without a bias
        """
        c = self.config
        # Resolved up front: both the synthesized zero points and the empty
        # g_idx placeholder are created on the quantized weight's device.
        # Binding it only in the zero-point branch left it undefined when
        # the config has zero points but no g_idx.
        device = getattr(layer, self.w_q_name).device

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # if the type has a bias we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to
                # exllama kernel adding 1 to the zero points during inference)
                # Documentation of the bug can be found here:
                #  https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full((groups, out_features),
                                   c.weight_type.bias - 1,
                                   dtype=torch.int32,
                                   device=device)
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "exllama kernel adding 1 to the zero points during "
                    "inference")
            zeros = pack_quantized_values_into_int32(zeros,
                                                     c.weight_type,
                                                     packed_dim=1)
            setattr(layer, self.w_zp_name,
                    torch.nn.Parameter(zeros, requires_grad=False))

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
        else:
            # No act reordering: register an empty g_idx so the kernel call
            # always has a tensor to pass.
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
                                                         dtype=torch.int,
                                                         device=device),
                                             requires_grad=False)
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Repack weights and scales for Exllama
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer's quantized weight.

        ``x`` is flattened to 2D for ``ops.gptq_gemm``; ``bias`` (if given)
        is added in place and the result is reshaped to
        ``x.shape[:-1] + (partition_weight_shape[1],)``.
        """
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
                               c.weight_type.size_bits)

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
vllm/model_executor/layers/quantization/kernels/machete.py
View file @
6d2051cc
...
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.utils.machete_utils import (
...
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES
,
check_machete_supports_shape
,
MACHETE_SUPPORTED_GROUP_SIZES
,
check_machete_supports_shape
,
query_machete_supported_quant_types
)
query_machete_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_
weight
s_into_int32
,
unpack_
weight
s_into_int32
)
pack_
quantized_value
s_into_int32
,
unpack_
quantized_value
s_into_int32
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
permute_param_layout_
)
...
@@ -71,13 +71,13 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -71,13 +71,13 @@ class MacheteLinearKernel(MPLinearKernel):
assert
isinstance
(
x
,
BasevLLMParameter
)
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
if
c
.
has_g_idx
:
if
c
.
has_g_idx
:
x_unpacked
=
unpack_
weight
s_into_int32
(
x
.
data
,
x_unpacked
=
unpack_
quantized_value
s_into_int32
(
x
.
data
,
c
.
weight_type
,
c
.
weight_type
,
packed_dim
=
0
)
packed_dim
=
0
)
x_perm
=
x_unpacked
[
perm
,
:]
x_perm
=
x_unpacked
[
perm
,
:]
x
.
data
=
pack_
weight
s_into_int32
(
x_perm
,
x
.
data
=
pack_
quantized_value
s_into_int32
(
x_perm
,
c
.
weight_type
,
c
.
weight_type
,
packed_dim
=
0
)
packed_dim
=
0
)
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
x
.
data
=
ops
.
machete_prepack_B
(
x
.
data
.
t
().
contiguous
().
t
(),
self
.
config
.
weight_type
)
self
.
config
.
weight_type
)
return
x
return
x
...
...
vllm/model_executor/layers/quantization/kernels/marlin.py
View file @
6d2051cc
...
@@ -38,10 +38,11 @@ class MarlinLinearKernel(MPLinearKernel):
...
@@ -38,10 +38,11 @@ class MarlinLinearKernel(MPLinearKernel):
"Marlin, supported group sizes are: "
\
"Marlin, supported group sizes are: "
\
f
"
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
f
"
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
return
check_marlin_supports_shape
(
c
.
partition_weight_shape
[
0
],
return
check_marlin_supports_shape
(
c
.
partition_weight_shape
[
1
],
c
.
partition_weight_shape
[
1
],
# out_features
c
.
full_weight_shape
[
1
],
c
.
partition_weight_shape
[
0
],
# in_features
c
.
group_size
)
c
.
full_weight_shape
[
0
],
# in_features
c
.
group_size
)
# note assumes that
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
6d2051cc
...
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
...
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
device
=
s
.
device
,
device
=
s
.
device
,
dtype
=
s
.
dtype
,
dtype
=
s
.
dtype
,
)
)
for
e
in
range
(
num_experts
):
for
e
in
range
(
num_experts
):
output
[
e
]
=
marlin_permute_scales
(
s
[
e
],
size_k
,
size_n
,
group_size
)
output
[
e
]
=
marlin_permute_scales
(
s
[
e
],
size_k
,
size_n
,
group_size
)
return
output
return
output
...
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
...
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return
marlin_zp
return
marlin_zp
def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
    """Apply ``awq_to_marlin_zero_points`` to every expert's zero points.

    :param q_zp_packed: tensor indexed by expert along dim 0; dims 1 and 2
        give the per-expert shape of the result
    :param size_k: forwarded unchanged to ``awq_to_marlin_zero_points``
    :param size_n: forwarded unchanged to ``awq_to_marlin_zero_points``
    :param num_bits: forwarded unchanged to ``awq_to_marlin_zero_points``
    :return: tensor with the device and dtype of ``q_zp_packed`` holding the
        converted zero points of each expert
    """
    expert_count = q_zp_packed.shape[0]
    result_shape = (expert_count, q_zp_packed.shape[1], q_zp_packed.shape[2])
    converted = torch.empty(
        result_shape,
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )

    # Convert one expert at a time and write it into its slot.
    for expert in range(expert_count):
        expert_zp = q_zp_packed[expert]
        converted[expert] = awq_to_marlin_zero_points(expert_zp, size_k,
                                                      size_n, num_bits)
    return converted
def
apply_gptq_marlin_linear
(
def
apply_gptq_marlin_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
6d2051cc
...
@@ -20,9 +20,9 @@ FUSED_LAYER_NAME_MAPPING = {
...
@@ -20,9 +20,9 @@ FUSED_LAYER_NAME_MAPPING = {
}
}
def
pack_
weight
s_into_int32
(
w_q
:
torch
.
Tensor
,
def
pack_
quantized_value
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
packed_dim
:
int
=
0
):
# move dim to pack to the end
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
...
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
...
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
return
res
.
permute
(
inv_perm
)
return
res
.
permute
(
inv_perm
)
def
unpack_
weight
s_into_int32
(
w_q
:
torch
.
Tensor
,
def
unpack_
quantized_value
s_into_int32
(
w_q
:
torch
.
Tensor
,
wtype
:
ScalarType
,
wtype
:
ScalarType
,
packed_dim
:
int
=
0
):
packed_dim
:
int
=
0
):
# move dim to pack to the end
# move dim to pack to the end
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
perm
=
(
*
[
i
for
i
in
range
(
len
(
w_q
.
shape
))
if
i
!=
packed_dim
],
packed_dim
)
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
inv_perm
=
tuple
(
perm
.
index
(
i
)
for
i
in
range
(
len
(
perm
)))
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
6d2051cc
...
@@ -159,7 +159,8 @@ def apply_fp8_linear(
...
@@ -159,7 +159,8 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight
# Making sure the dummy tensor is on the same device as the weight
global
TORCH_DEVICE_IDENTITY
global
TORCH_DEVICE_IDENTITY
if
TORCH_DEVICE_IDENTITY
.
device
!=
weight
.
device
:
if
(
TORCH_DEVICE_IDENTITY
is
not
None
and
TORCH_DEVICE_IDENTITY
.
device
!=
weight
.
device
):
TORCH_DEVICE_IDENTITY
=
TORCH_DEVICE_IDENTITY
.
to
(
weight
.
device
)
TORCH_DEVICE_IDENTITY
=
TORCH_DEVICE_IDENTITY
.
to
(
weight
.
device
)
# GEMM
# GEMM
...
@@ -191,13 +192,28 @@ def apply_int8_linear(
...
@@ -191,13 +192,28 @@ def apply_int8_linear(
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
azp_adj
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
symmetric
=
azp_adj
is
None
x_q
,
x_scale
,
x_zp
=
ops
.
scaled_int8_quant
(
input
,
input_scale
,
input_zero_point
,
symmetric
=
symmetric
)
if
x_zp
is
not
None
:
return
ops
.
cutlass_scaled_mm_azp
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
azp_adj
=
azp_adj
,
azp
=
x_zp
,
bias
=
bias
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
weight
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
...
...
Prev
1
…
16
17
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment