Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
92545504
Commit
92545504
authored
Mar 14, 2025
by
gaoqiong
Browse files
增加deepseek v3 block-int8量化支持
parent
146eb9d3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
998 additions
and
13 deletions
+998
-13
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+41
-8
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+3
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+5
-2
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-1
vllm/model_executor/layers/quantization/blockwise_int8.py
vllm/model_executor/layers/quantization/blockwise_int8.py
+424
-0
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+517
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+3
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
92545504
...
...
@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -265,6 +268,7 @@ def fused_moe_kernel(
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
...
...
@@ -346,7 +350,7 @@ def fused_moe_kernel(
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
...
...
@@ -376,7 +380,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
@@ -402,7 +406,7 @@ def fused_moe_kernel(
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
...
...
@@ -709,6 +713,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -727,6 +732,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_int8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
...
...
@@ -826,6 +844,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
...
...
@@ -1060,9 +1079,12 @@ def grouped_topk(hidden_states: torch.Tensor,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a8
:
return
"int8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
use_int4_w4a16
:
...
...
@@ -1080,6 +1102,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1094,7 +1117,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
...
...
@@ -1107,6 +1130,7 @@ def inplace_fused_experts_fake(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1138,6 +1162,7 @@ def outplace_fused_experts(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1152,7 +1177,7 @@ def outplace_fused_experts(
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
False
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
...
...
@@ -1166,6 +1191,7 @@ def outplace_fused_experts_fake(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1197,6 +1223,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1213,7 +1240,7 @@ def fused_experts(hidden_states: torch.Tensor,
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
...
...
@@ -1224,7 +1251,7 @@ def fused_experts(hidden_states: torch.Tensor,
return
hidden_states
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
...
...
@@ -1239,6 +1266,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1279,6 +1307,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
...
...
@@ -1369,6 +1398,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
...
...
@@ -1393,6 +1423,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
...
...
@@ -1416,6 +1447,7 @@ def fused_moe(
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1492,6 +1524,7 @@ def fused_moe(
topk_ids
,
inplace
=
inplace
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
92545504
...
...
@@ -363,6 +363,9 @@ class FusedMoE(torch.nn.Module):
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsWNA16MoEMethod"
):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
)):
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
...
vllm/model_executor/layers/linear.py
View file @
92545504
...
...
@@ -37,7 +37,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"HQQMarlinMethod"
,
"QuarkLinearMethod"
"HQQMarlinMethod"
,
"QuarkLinearMethod"
,
"BlockInt8LinearMethod"
,
]
...
...
@@ -664,9 +664,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
from
vllm.model_executor.layers.quantization.blockwise_int8
import
(
BlockInt8LinearMethod
,
BlockInt8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant_method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
))
(
Fp8LinearMethod
,
Fp8MoEMethod
,
BlockInt8LinearMethod
,
BlockInt8MoEMethod
))
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
92545504
...
...
@@ -29,7 +29,8 @@ QUANTIZATION_METHODS: List[str] = [
"neuron_quant"
,
"ipex"
,
"quark"
,
"moe_wna16"
"moe_wna16"
,
"blockwise_int8"
]
# The customized quantization methods which will be added to this dict.
...
...
@@ -101,6 +102,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.neuron_quant
import
NeuronQuantConfig
from
.qqq
import
QQQConfig
from
.tpu_int8
import
Int8TpuConfig
from
.blockwise_int8
import
BlockInt8Config
method_to_config
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
...
...
@@ -127,6 +129,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
,
"moe_wna16"
:
MoeWNA16Config
,
"blockwise_int8"
:
BlockInt8Config
,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/blockwise_int8.py
0 → 100755
View file @
92545504
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/3730
import
logging
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
apply_w8a8_block_int8_linear
)
from
vllm.model_executor.utils
import
set_weight_attrs
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
logger
=
logging
.
getLogger
(
__name__
)
class
BlockInt8Config
(
QuantizationConfig
):
"""Config class for INT8."""
def
__init__
(
self
,
is_checkpoint_int8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
weight_block_size
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
self
.
is_checkpoint_int8_serialized
=
is_checkpoint_int8_serialized
if
is_checkpoint_int8_serialized
:
logger
.
warning
(
"Detected int8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if
activation_scheme
not
in
ACTIVATION_SCHEMES
:
raise
ValueError
(
"Unsupported activation scheme"
f
"
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
if
weight_block_size
is
not
None
:
if
not
is_checkpoint_int8_serialized
:
raise
ValueError
(
f
"The block-wise quantization only supports "
"int8-serialized checkpoint for now."
)
if
len
(
weight_block_size
)
!=
2
:
raise
ValueError
(
f
"The quantization block size of weight must have 2 "
"dimensions, but got {len(weight_block_size)} dimensions."
)
if
activation_scheme
!=
"dynamic"
:
raise
ValueError
(
f
"The block-wise quantization only supports dynamic "
"activation scheme for now, but got "
"{activation_scheme} activation scheme."
)
self
.
weight_block_size
=
weight_block_size
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"blockwise_int8"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BlockInt8Config"
:
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_int8_serialized
=
"int8"
in
quant_method
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
weight_block_size
=
cls
.
get_from_keys_or
(
config
,
[
"weight_block_size"
],
None
)
return
cls
(
is_checkpoint_int8_serialized
=
is_checkpoint_int8_serialized
,
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
,
weight_block_size
=
weight_block_size
,
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignored_layers
):
return
UnquantizedLinearMethod
()
return
BlockInt8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
BlockInt8MoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
BlockInt8LinearMethod
(
LinearMethodBase
):
"""Linear method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic activation scale.
Limitations:
Only support block-wise int8 quantization and int8 checkpoint
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
BlockInt8Config
):
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
is_checkpoint_int8_serialized
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
Optional
[
List
[
int
]],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
# assert output_partition_sizes is not None, (
# "output_partition_sizes must be provided for quantization")
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# Required by row parallel
if
tp_size
>
1
and
input_size
//
input_size_per_partition
==
tp_size
:
if
input_size_per_partition
%
block_k
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# Required by collum parallel or enabling merged weights
if
(
tp_size
>
1
and
output_size
//
output_size_per_partition
==
tp_size
)
or
len
(
output_partition_sizes
)
>
1
:
for
output_partition_size
in
output_partition_sizes
:
if
output_partition_size
%
block_n
!=
0
:
raise
ValueError
(
f
"Weight output_partition_size = "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight_dtype
=
(
torch
.
int8
if
self
.
quant_config
.
is_checkpoint_int8_serialized
else
params_dtype
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
weight_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
# WEIGHT SCALE
scale
=
BlockQuantScaleParameter
(
data
=
torch
.
empty
(
(
output_size_per_partition
+
block_n
-
1
)
//
block_n
,
(
input_size_per_partition
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale_inv"
,
scale
)
# INPUT ACTIVATION SCALE
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
# Use torch Parameter to avoid cuda graph capturing issue
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale_inv
.
data
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
apply_w8a8_block_int8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
block_size
=
self
.
quant_config
.
weight_block_size
,
weight_scale
=
layer
.
weight_scale_inv
,
input_scale
=
None
,
bias
=
bias
,
)
class
BlockInt8MoEMethod
:
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic activation scale.
Limitations:
Only support block-wise int8 quantization and int8 checkpoint
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
is_checkpoint_int8_serialized
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
from
vllm.model_executor.layers.fused_moe
import
FusedMoeWeightScaleSupported
if
self
.
quant_config
.
is_checkpoint_int8_serialized
:
params_dtype
=
torch
.
int8
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by collum parallel or enabling merged weights
if
intermediate_size
%
block_n
!=
0
:
raise
ValueError
(
f
"The output_size of gate's and up's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
if
tp_size
>
1
:
# Required by row parallel
if
intermediate_size
%
block_k
!=
0
:
raise
ValueError
(
f
"The input_size of down's weight = "
f
"
{
intermediate_size
}
is not divisible by "
f
"weight quantization block_k =
{
block_k
}
."
)
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
((
intermediate_size
+
block_n
-
1
)
//
block_n
),
(
hidden_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
(
hidden_size
+
block_n
-
1
)
//
block_n
,
(
intermediate_size
+
block_k
-
1
)
//
block_k
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale_inv"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale_inv"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
BLOCK
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
return
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
#print("===========fused_experts========================")
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
# Expert fusion with INT8 quantization
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int8_w8a8
=
True
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
),
w2_scale
=
(
layer
.
w2_weight_scale_inv
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
use_nn_moe
=
use_nn_moe
,
moe_ep_size
=
moe_ep_size
,
start_expert
=
start_expert
,
end_expert
=
end_expert
)
vllm/model_executor/layers/quantization/utils/int8_utils.py
0 → 100755
View file @
92545504
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/3730
import
functools
import
json
import
logging
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
# from sglang.srt.utils import get_device_name
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
@
triton
.
jit
def
_per_token_quant_int8
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
BLOCK
:
tl
.
constexpr
,
):
row_id
=
tl
.
program_id
(
0
)
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
row_id
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
x
*
(
127
/
absmax
)
x_q
=
tl
.
extra
.
cuda
.
libdevice
.
round
(
x_q
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
row_id
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
row_id
,
scale_x
)
def
per_token_quant_int8
(
x
):
M
=
x
.
numel
()
//
x
.
shape
[
-
1
]
N
=
x
.
shape
[
-
1
]
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
torch
.
int8
)
scales
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
1
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
assert
x
.
is_contiguous
()
_per_token_quant_int8
[(
M
,)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
return
x_q
,
scales
@
triton
.
jit
def
_per_token_group_quant_int8
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
y_s_ptr
,
# Stride of input
y_stride
,
# Collums of input
N
,
# Avoid to divide zero
eps
,
# Information for int8
int8_min
,
int8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
y_ptr
+=
g_id
*
y_stride
y_q_ptr
+=
g_id
*
y_stride
y_s_ptr
+=
g_id
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
mask
=
cols
<
N
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
int8_max
y_q
=
tl
.
clamp
(
y
/
y_s
,
int8_min
,
int8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
def
per_token_group_quant_int8
(
x
:
torch
.
Tensor
,
group_size
:
int
,
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
=
torch
.
int8
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.int8`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
"the last dimension of `x` cannot be divisible by `group_size`"
assert
x
.
is_contiguous
(),
"`x` is not contiguous"
iinfo
=
torch
.
iinfo
(
dtype
)
int8_max
=
iinfo
.
max
int8_min
=
iinfo
.
min
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
_per_token_group_quant_int8
[(
M
,)](
x
,
x_q
,
x_s
,
group_size
,
N
,
eps
,
int8_min
=
int8_min
,
int8_max
=
int8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
x_q
,
x_s
@
triton
.
jit
def
_w8a8_block_int8_matmul
(
# Pointers to inputs and output
A
,
B
,
C
,
As
,
Bs
,
# Shape for matmul
M
,
N
,
K
,
# Block size for block-wise quantization
group_n
,
group_k
,
# Stride for inputs and output
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_As_m
,
stride_As_k
,
stride_Bs_k
,
stride_Bs_n
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
M
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
(
pid
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
offs_am
=
(
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
))
%
M
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
A
+
(
offs_am
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
B
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
As_ptrs
=
As
+
offs_am
*
stride_As_m
offs_bsn
=
offs_bn
//
group_n
Bs_ptrs
=
Bs
+
offs_bsn
*
stride_Bs_n
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
accumulator
+=
tl
.
dot
(
a
,
b
).
to
(
tl
.
float32
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
C
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
C
.
dtype
.
element_ty
==
tl
.
float16
:
c
=
accumulator
.
to
(
tl
.
float16
)
else
:
c
=
accumulator
.
to
(
tl
.
float32
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
C
+
stride_cm
*
offs_cm
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
@
functools
.
lru_cache
def
get_w8a8_block_int8_configs
(
N
:
int
,
K
:
int
,
block_n
:
int
,
block_k
:
int
)
->
Optional
[
Dict
[
int
,
Any
]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
json_file_name
=
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=int8_w8a8,block_shape=[
{
block_n
}
,
{
block_k
}
].json"
# noqa: E501
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for W8A8 Block INT8 kernel."
,
config_file_path
,
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
config_file_path
,
)
return
None
def
w8a8_block_int8_matmul
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
block_size
:
List
[
int
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
"""matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
and
A
.
is_contiguous
()
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
B
.
is_contiguous
()
and
Bs
.
ndim
==
2
N
,
K
=
B
.
shape
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
configs
=
get_w8a8_block_int8_configs
(
N
,
K
,
block_size
[
0
],
block_size
[
1
])
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
config
=
{
"BLOCK_SIZE_M"
:
32
,
#64
"BLOCK_SIZE_N"
:
block_size
[
0
],
"BLOCK_SIZE_K"
:
block_size
[
1
],
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
,
}
def
grid
(
META
):
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
_w8a8_block_int8_matmul
[
grid
](
A
,
B
,
C
,
As
,
Bs
,
M
,
N
,
K
,
block_n
,
block_k
,
A
.
stride
(
-
2
),
A
.
stride
(
-
1
),
B
.
stride
(
1
),
B
.
stride
(
0
),
C
.
stride
(
-
2
),
C
.
stride
(
-
1
),
As
.
stride
(
-
2
),
As
.
stride
(
-
1
),
Bs
.
stride
(
1
),
Bs
.
stride
(
0
),
**
config
,
)
return
C
def
native_w8a8_block_int8_matmul
(
A
,
B
,
As
,
Bs
,
block_size
,
output_dtype
=
torch
.
float16
):
"""matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A
=
A
.
to
(
torch
.
float32
)
B
=
B
.
to
(
torch
.
float32
)
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
B
.
is_contiguous
()
and
Bs
.
ndim
==
2
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
(
A
.
shape
[
-
1
]
+
block_k
-
1
)
//
block_k
==
As
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
N
,
K
=
B
.
shape
origin_C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
A
=
A
.
reshape
(
M
,
A
.
shape
[
-
1
])
As
=
As
.
reshape
(
M
,
As
.
shape
[
-
1
])
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
assert
n_tiles
==
Bs
.
shape
[
0
]
assert
k_tiles
==
Bs
.
shape
[
1
]
C_shape
=
(
M
,
N
)
C
=
torch
.
zeros
(
C_shape
,
dtype
=
torch
.
float32
,
device
=
A
.
device
)
A_tiles
=
[
A
[:,
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
)]
for
i
in
range
(
k_tiles
)]
B_tiles
=
[
[
B
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
),
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
),
]
for
i
in
range
(
k_tiles
)
]
for
j
in
range
(
n_tiles
)
]
C_tiles
=
[
C
[:,
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
)]
for
j
in
range
(
n_tiles
)]
As_tiles
=
[
As
[:,
i
:
i
+
1
]
for
i
in
range
(
k_tiles
)]
for
i
in
range
(
k_tiles
):
for
j
in
range
(
n_tiles
):
a
=
A_tiles
[
i
]
b
=
B_tiles
[
j
][
i
]
c
=
C_tiles
[
j
]
s
=
As_tiles
[
i
]
*
Bs
[
j
][
i
]
c
[:,
:]
+=
torch
.
matmul
(
a
,
b
.
t
())
*
s
C
=
C
.
reshape
(
origin_C_shape
).
to
(
output_dtype
)
return
C
def
apply_w8a8_block_int8_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
block_size
:
List
[
int
],
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
input_scale
is
None
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
q_input
,
x_scale
=
per_token_group_quant_int8
(
input_2d
,
block_size
[
1
])
output
=
w8a8_block_int8_matmul
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
output_dtype
=
input
.
dtype
)
# output = native_w8a8_block_int8_matmul(
# q_input, weight, x_scale, weight_scale, block_size,
# output_dtype=input.dtype
# )
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
def
input_to_int8
(
x
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
int8
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""This function quantizes input values to
int8 values with tensor-wise quantization.
"""
iinfo
=
torch
.
iinfo
(
dtype
)
min_val
,
max_val
=
x
.
aminmax
()
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
int8_min
,
int8_max
=
iinfo
.
min
,
iinfo
.
max
scale
=
int8_max
/
amax
x_scl_sat
=
(
x
*
scale
).
clamp
(
min
=
int8_min
,
max
=
int8_max
)
return
x_scl_sat
.
to
(
dtype
).
contiguous
(),
scale
.
float
().
reciprocal
()
def
block_dequant
(
x_q_block
:
torch
.
Tensor
,
x_s
:
torch
.
Tensor
,
block_size
:
List
[
int
],
)
->
torch
.
Tensor
:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are dequantized tensor.
"""
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n
,
k
=
x_q_block
.
shape
n_tiles
=
(
n
+
block_n
-
1
)
//
block_n
k_tiles
=
(
k
+
block_k
-
1
)
//
block_k
assert
n_tiles
==
x_s
.
shape
[
0
]
assert
k_tiles
==
x_s
.
shape
[
1
]
x_dq_block
=
x_q_block
.
to
(
torch
.
float32
)
for
i
in
range
(
k_tiles
):
for
j
in
range
(
n_tiles
):
x_dq_block
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
n
),
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
k
),
]
*=
x_s
[
j
][
i
]
return
x_dq_block
\ No newline at end of file
vllm/model_executor/models/deepseek_v2.py
View file @
92545504
...
...
@@ -59,7 +59,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
block_dequant
as
int8_block_dequant
,
)
class
DeepseekV2MLP
(
nn
.
Module
):
...
...
vllm/platforms/rocm.py
View file @
92545504
...
...
@@ -72,7 +72,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"moe_wna16"
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"moe_wna16"
,
"blockwise_int8"
]
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment