Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d0c9d6b
Unverified
Commit
1d0c9d6b
authored
May 06, 2025
by
Jinzhen Lin
Committed by
GitHub
May 05, 2025
Browse files
[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
f62cad64
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
378 additions
and
112 deletions
+378
-112
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+57
-42
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+11
-8
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
...tor/layers/quantization/kernels/mixed_precision/marlin.py
+2
-4
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+60
-16
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+237
-42
vllm/scalar_type.py
vllm/scalar_type.py
+11
-0
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
1d0c9d6b
...
@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_moe_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
all_close_1d
,
c
onvert_to_channelwise
,
Fp8LinearOp
,
all_close_1d
,
c
utlass_block_fp8_supported
,
cutlass_
block_
fp8_supported
,
cutlass_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
per_tensor_dequantize
,
requantize_with_max_scale
)
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
use_marlin
=
False
self
.
use_marlin
=
False
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
if
self
.
block_quant
:
# Marlin doesn't support block-wise fp8
self
.
use_marlin
=
False
self
.
fp8_linear
=
Fp8LinearOp
(
self
.
fp8_linear
=
Fp8LinearOp
(
# Default to using per_token quantization if cutlass is supported
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic
=
cutlass_fp8_supported
())
use_per_token_if_dynamic
=
cutlass_fp8_supported
())
...
@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
if
self
.
block_quant
:
if
self
.
block_quant
:
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
block_n
,
block_k
=
(
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
1
],
self
.
quant_config
.
weight_block_size
[
1
],
...
@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
f
"
{
output_partition_size
}
is not divisible by "
f
"
{
output_partition_size
}
is not divisible by "
f
"weight quantization block_n =
{
block_n
}
."
)
f
"weight quantization block_n =
{
block_n
}
."
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
...
@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
return
weight
return
weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
size_k_first
=
True
# TODO(rob): refactor block quant into separate class.
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
size_k_first
=
False
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale_inv
,
_
=
\
weight
,
weight_scale_inv
,
_
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
requires_grad
=
False
)
requires_grad
=
False
)
return
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
el
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
scale
=
None
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
assert
weight_scale
.
numel
()
==
1
weight_scale
=
convert_to_channelwise
(
weight_scale
.
expand
(
len
(
layer
.
logical_widths
)),
layer
.
logical_widths
)
# Update the layer with the new values.
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
...
@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
requires_grad
=
False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
weight
=
layer
.
weight
if
self
.
use_marlin
:
weight_scale
=
layer
.
weight_scale
weight
=
layer
.
weight
weight_scale
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
# requantize the logical shards as a single weight.
else
:
if
not
self
.
use_marlin
:
# Dequant -> Quant with max scale so we can run per tensor.
# Dequant -> Quant with max scale so we can run per tensor.
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
if
current_platform
.
is_fp8_fnuz
():
if
current_platform
.
is_fp8_fnuz
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
requires_grad
=
False
)
requires_grad
=
False
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
prepare_fp8_layer_for_marlin
(
layer
,
size_k_first
)
# Activations not quantized for marlin.
# Activations not quantized for marlin.
del
layer
.
input_scale
del
layer
.
input_scale
...
@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
# Check for DeepGemm support.
# Check for DeepGemm support.
self
.
allow_deep_gemm
=
False
self
.
allow_deep_gemm
=
False
if
envs
.
VLLM_USE_DEEP_GEMM
:
if
envs
.
VLLM_USE_DEEP_GEMM
:
...
@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
params_dtype
=
torch
.
float8_e4m3fn
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
block_n
,
block_k
=
(
block_n
,
block_k
=
(
self
.
quant_config
.
weight_block_size
[
0
],
self
.
quant_config
.
weight_block_size
[
0
],
...
@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_weight_scale_inv
=
\
layer
.
w2_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
return
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
el
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
dtype
=
fp8_dtype
)
...
@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
requires_grad
=
False
)
return
# If checkpoint is fp8, we need to handle that the
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
# scale for w13 per expert.
...
@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
requires_grad
=
False
)
return
if
self
.
use_marlin
:
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
)
# Activations not quantized for marlin.
del
layer
.
w13_input_scale
del
layer
.
w2_input_scale
def
apply
(
def
apply
(
self
,
self
,
...
@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
)
)
if
self
.
use_marlin
:
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
1d0c9d6b
...
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
...
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method
)
get_linear_quant_method
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
check_moe_marlin_supports_layer
,
check_marlin_supported
,
check_moe_marlin_supports_layer
,
marlin_m
oe_permute_scales
,
marlin_repeat_scales_on_all_rank
s
,
marlin_m
ake_workspace_new
,
marlin_moe_permute_scale
s
,
verify_marlin_supported
)
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedColumnParameter
,
...
@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
if
self
.
quant_config
.
quant_type
.
size_bits
==
4
:
self
.
quant_type
=
scalar_types
.
uint4b8
elif
self
.
quant_config
.
quant_type
.
size_bits
==
8
:
self
.
quant_type
=
scalar_types
.
uint8b128
else
:
raise
ValueError
(
"GPTQMarlinMoEMethod only supports int4 and int8 now."
)
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
device
=
layer
.
w13_qweight
.
device
device
=
layer
.
w13_qweight
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
layer
.
workspace
=
torch
.
zeros
((
sms
*
4
,
),
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
quant_type_id
=
self
.
quant_type
.
id
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
g_idx1
=
layer
.
w13_g_idx
,
g_idx1
=
layer
.
w13_g_idx
,
g_idx2
=
layer
.
w2_g_idx
,
g_idx2
=
layer
.
w2_g_idx
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
is_k_full
=
self
.
is_k_full
)
is_k_full
=
self
.
is_k_full
)
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
View file @
1d0c9d6b
...
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
...
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MARLIN_SUPPORTED_GROUP_SIZES
,
apply_gptq_marlin_linear
,
MARLIN_SUPPORTED_GROUP_SIZES
,
apply_gptq_marlin_linear
,
check_marlin_supports_shape
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
check_marlin_supports_shape
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_sort_g_idx
,
marlin_make_workspace
_new
,
marlin_permute_scales
,
marlin_sort_g_idx
,
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
permute_param_layout_
)
...
@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
...
@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
self
.
is_k_full
=
marlin_is_k_full
(
c
.
has_g_idx
,
row_parallel
)
self
.
is_k_full
=
marlin_is_k_full
(
c
.
has_g_idx
,
row_parallel
)
# Allocate marlin workspace.
# Allocate marlin workspace.
self
.
workspace
=
marlin_make_workspace
(
c
.
partition_weight_shape
[
1
],
self
.
workspace
=
marlin_make_workspace_new
(
device
)
device
)
# Default names since marlin requires empty parameters for these,
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
# TODO: remove this requirement from marlin (allow optional tensors)
...
@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
...
@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
wtype
=
c
.
weight_type
,
wtype
=
c
.
weight_type
,
input_size_per_partition
=
c
.
partition_weight_shape
[
0
],
input_size_per_partition
=
c
.
partition_weight_shape
[
0
],
output_size_per_partition
=
c
.
partition_weight_shape
[
1
],
output_size_per_partition
=
c
.
partition_weight_shape
[
1
],
has_zp
=
self
.
config
.
zero_points
,
is_k_full
=
self
.
is_k_full
,
is_k_full
=
self
.
is_k_full
,
bias
=
bias
)
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
1d0c9d6b
...
@@ -7,12 +7,15 @@ import torch
...
@@ -7,12 +7,15 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
from
.quant_utils
import
pack_cols
,
unpack_cols
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MIN_THREAD_K
=
128
...
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
...
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
# TODO: we may want to move this into the C++ so its closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
def
query_marlin_supported_quant_types
(
device_capability
:
Optional
[
int
]
=
None
has_zp
:
bool
,
):
include_fp_type
:
bool
=
True
,
device_capability
:
Optional
[
int
]
=
None
,
):
if
device_capability
is
None
:
if
device_capability
is
None
:
capability_tuple
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
(
-
1
if
capability_tuple
is
None
else
device_capability
=
(
-
1
if
capability_tuple
is
None
else
...
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
...
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
if
has_zp
:
if
has_zp
:
# AWQ style, unsigned + runtime zero-point
# AWQ style, unsigned + runtime zero-point
return
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
return
[
scalar_types
.
uint4
]
else
:
else
:
# GPTQ style, unsigned + symmetric bias
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
res
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
# to add `scalar_types.float8_e4m3fn` here
if
include_fp_type
:
return
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
res
+=
[
scalar_types
.
float8_e4m3fn
]
return
res
def
_check_marlin_supported
(
def
_check_marlin_supported
(
...
@@ -62,7 +68,7 @@ def _check_marlin_supported(
...
@@ -62,7 +68,7 @@ def _check_marlin_supported(
capability_tuple
.
to_int
())
capability_tuple
.
to_int
())
supported_types
=
query_marlin_supported_quant_types
(
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
device_capability
)
has_zp
,
True
,
device_capability
)
if
quant_type
not
in
supported_types
:
if
quant_type
not
in
supported_types
:
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
...
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
...
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad
=
False
)
requires_grad
=
False
)
def
marlin_make_workspace_new
(
device
:
torch
.
device
,
max_blocks_per_sm
:
int
=
1
)
->
torch
.
Tensor
:
# In the new marlin kernel, we use the num of threadblocks as workspace
# size. The num of threadblocks is is sms_count * max_blocks_per_sm.
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
return
torch
.
zeros
(
sms
*
max_blocks_per_sm
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
def
marlin_is_k_full
(
act_order
:
bool
,
is_row_parallel
:
bool
)
->
bool
:
def
marlin_is_k_full
(
act_order
:
bool
,
is_row_parallel
:
bool
)
->
bool
:
return
(
not
act_order
)
or
(
act_order
and
not
is_row_parallel
)
return
(
not
act_order
)
or
(
act_order
and
not
is_row_parallel
)
...
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
...
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return
output
return
output
def
maybe_warn_marlin_atomic_add
(
device
,
dtype
):
if
torch
.
compiler
.
is_dynamo_compiling
():
return
device_capability
=
torch
.
cuda
.
get_device_capability
(
device
)
if
device_capability
[
0
]
<
9
and
dtype
==
torch
.
bfloat16
:
logger
.
info_once
(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible."
)
def
maybe_warn_marlin_atomic_add_env
():
if
torch
.
compiler
.
is_dynamo_compiling
():
return
if
envs
.
VLLM_MARLIN_USE_ATOMIC_ADD
:
return
logger
.
info_once
(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
)
def
should_use_atomic_add_reduce
(
m
:
int
,
n
:
int
,
k
:
int
,
device
:
torch
.
device
,
def
should_use_atomic_add_reduce
(
m
:
int
,
n
:
int
,
k
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
bool
:
dtype
:
torch
.
dtype
)
->
bool
:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if
n
>=
2048
or
k
<
2048
or
device
.
type
!=
"cuda"
:
return
False
# disable atomicAdd reduce by default,
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if
not
envs
.
VLLM_MARLIN_USE_ATOMIC_ADD
or
device
.
type
!=
"cuda"
:
if
not
envs
.
VLLM_MARLIN_USE_ATOMIC_ADD
:
maybe_warn_marlin_atomic_add_env
()
return
False
return
False
# sm8x doesn't support atomicAdd + bfloat16 natively
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability
=
torch
.
cuda
.
get_device_capability
(
device
)
device_capability
=
torch
.
cuda
.
get_device_capability
(
device
)
if
device_capability
[
0
]
<
9
and
dtype
==
torch
.
bfloat16
:
if
device_capability
[
0
]
<
9
and
dtype
==
torch
.
bfloat16
:
maybe_warn_marlin_atomic_add
(
device
,
dtype
)
return
False
return
False
# the performance of atomicAdd is better than global reduce
return
True
# only when m*n is small and k is large
return
n
<
2048
and
k
>=
2048
def
apply_gptq_marlin_linear
(
def
apply_gptq_marlin_linear
(
...
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
...
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
wtype
:
ScalarType
,
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
has_zp
:
bool
,
is_k_full
:
bool
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
...
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
...
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
dtype
=
input
.
dtype
)
dtype
=
input
.
dtype
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
None
,
weight
,
weight
,
weight_scale
,
weight_scale
,
weight_zp
,
weight_zp
,
...
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
...
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
use_atomic_add
=
use_atomic_add
,
use_atomic_add
=
use_atomic_add
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
is_zp_float
=
False
)
...
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
...
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
dtype
=
input
.
dtype
)
dtype
=
input
.
dtype
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
None
,
weight
,
weight
,
weight_scale
,
weight_scale
,
weight_zp
,
weight_zp
,
...
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
...
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
size_m
=
reshaped_x
.
shape
[
0
],
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
has_zp
=
True
,
use_atomic_add
=
use_atomic_add
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
is_zp_float
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
1d0c9d6b
...
@@ -6,9 +6,11 @@ import torch
...
@@ -6,9 +6,11 @@ import torch
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
USE_FP32_REDUCE_DEFAULT
,
marlin_make_workspace_new
,
marlin_permute_scales
,
should_use_atomic_add_reduce
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
.marlin_utils
import
marlin_make_workspace
,
marlin_permute_scales
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
...
@@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
def
apply_fp8_marlin_linear
(
def
apply_fp8_marlin_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_n
:
int
,
size_k
:
int
,
size_k
:
int
,
bias
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
# For GPUs that lack FP8 hardware support, we can leverage the
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
size_n
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
size_n
,
)
output
=
ops
.
fp8_marlin_gemm
(
use_atomic_add
=
should_use_atomic_add_reduce
(
m
=
reshaped_x
.
size
(
0
),
a
=
reshaped_x
,
n
=
size_n
,
b_q_weight
=
weight
,
k
=
size_k
,
b_scales
=
weight_scale
,
device
=
input
.
device
,
workspace
=
workspace
,
dtype
=
input
.
dtype
)
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
output
=
ops
.
gptq_marlin_gemm
(
a
=
reshaped_x
,
size_n
=
size_n
,
c
=
None
,
size_k
=
size_k
,
b_q_weight
=
weight
,
)
b_scales
=
weight_scale
,
b_zeros
=
None
,
g_idx
=
None
,
perm
=
None
,
workspace
=
workspace
,
b_q_type
=
scalar_types
.
float8_e4m3fn
,
size_m
=
reshaped_x
.
size
(
0
),
size_n
=
size_n
,
size_k
=
size_k
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
...
@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
s
trategy
:
str
=
"tensor"
)
->
None
:
s
ize_k_first
:
bool
=
True
)
->
None
:
logger
.
warning_once
(
logger
.
warning_once
(
"Your GPU does not have native support for FP8 computation but "
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"FP8 quantization is being used. Weight-only FP8 compression will "
...
@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n
=
layer
.
output_size_per_partition
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
if
size_k_first
:
assert
layer
.
weight
.
shape
==
(
part_size_k
,
part_size_n
)
else
:
assert
layer
.
weight
.
shape
==
(
part_size_n
,
part_size_k
)
device
=
layer
.
weight
.
device
device
=
layer
.
weight
.
device
# WORKSPACE
# WORKSPACE
layer
.
workspace
=
marlin_make_workspace
(
part_size_n
,
device
)
layer
.
workspace
=
marlin_make_workspace
_new
(
device
)
# WEIGHT
# WEIGHT
# Repack weights to marlin format
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
pack_fp8_to_int32
(
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
layer
.
weight
),
qweight
=
pack_fp8_to_int32
(
layer
.
weight
,
size_k_first
)
perm
=
torch
.
empty
(
0
,
if
not
size_k_first
:
dtype
=
torch
.
int
,
qweight
=
qweight
.
T
.
contiguous
()
device
=
device
),
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
qweight
,
perm
=
perm
,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
size_n
=
part_size_n
,
num_bits
=
8
)
num_bits
=
8
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# WEIGHT SCALES
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
# Permute scales
# Permute scales
if
"weight_scale"
in
dir
(
layer
):
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
elif
"weight_scale_inv"
in
dir
(
layer
):
scales
=
layer
.
weight_scale_inv
.
to
(
layer
.
orig_dtype
)
del
layer
.
weight_scale_inv
if
layer
.
weight_block_size
is
None
:
group_size
=
-
1
else
:
group_size
=
layer
.
weight_block_size
[
1
]
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if
layer
.
weight_block_size
is
None
:
if
scales
.
nelement
()
==
1
:
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
scales
=
scales
.
view
(
1
,
1
).
repeat_interleave
(
part_size_n
,
1
)
elif
scales
.
nelement
()
>
1
and
scales
.
nelement
()
!=
part_size_n
:
assert
part_size_n
%
scales
.
nelement
()
==
0
s_size
=
scales
.
nelement
()
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (1, s_size) =>(repeat)=> (1, size_n)
scales
=
scales
.
view
(
1
,
s_size
)
scales
=
scales
.
repeat_interleave
(
part_size_n
//
s_size
,
1
)
else
:
# channel-wise quantization
# (1, size_n)
scales
=
scales
.
view
(
1
,
part_size_n
)
else
:
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
block_n
=
layer
.
weight_block_size
[
0
]
scales
=
scales
.
T
.
repeat_interleave
(
block_n
,
1
)
# size_n may not divisible by block_size[0]
scales
=
scales
[:,
:
part_size_n
]
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
size_n
=
part_size_n
,
group_size
=
-
1
)
group_size
=
group_size
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
prepare_moe_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
size_k_first
:
bool
=
True
)
->
None
:
logger
.
warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
e
=
layer
.
num_experts
k
=
layer
.
hidden_size
n
=
layer
.
intermediate_size_per_partition
# WORKSPACE
device
=
layer
.
w13_weight
.
device
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
)
# WEIGHT
# Repack weights to marlin format
for
name
in
[
"w13_weight"
,
"w2_weight"
]:
weight
=
getattr
(
layer
,
name
)
tensor_list
=
[]
if
"w13"
in
name
:
size_n
,
size_k
=
n
*
2
,
k
else
:
size_n
,
size_k
=
k
,
n
if
size_k_first
:
assert
weight
.
shape
==
(
e
,
size_k
,
size_n
)
else
:
assert
weight
.
shape
==
(
e
,
size_n
,
size_k
)
for
i
in
range
(
e
):
qweight
=
pack_fp8_to_int32
(
weight
[
i
],
size_k_first
)
if
not
size_k_first
:
qweight
=
qweight
.
T
.
contiguous
()
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
qweight
,
perm
=
perm
,
size_k
=
size_k
,
size_n
=
size_n
,
num_bits
=
8
)
tensor_list
.
append
(
marlin_qweight
)
weight
=
torch
.
cat
([
x
.
unsqueeze
(
0
)
for
x
in
tensor_list
],
0
)
weight
=
torch
.
nn
.
Parameter
(
weight
,
requires_grad
=
False
)
setattr
(
layer
,
name
,
weight
)
# WEIGHT SCALES
# Permute scales
if
layer
.
weight_block_size
is
None
:
group_size
=
-
1
else
:
group_size
=
layer
.
weight_block_size
[
1
]
for
name
in
[
"w13"
,
"w2"
]:
if
name
+
"_weight_scale"
in
dir
(
layer
):
new_name
=
name
+
"_weight_scale"
scales
=
getattr
(
layer
,
new_name
).
to
(
layer
.
orig_dtype
)
delattr
(
layer
,
new_name
)
elif
name
+
"_weight_scale_inv"
in
dir
(
layer
):
new_name
=
name
+
"_weight_scale_inv"
scales
=
getattr
(
layer
,
new_name
).
to
(
layer
.
orig_dtype
)
delattr
(
layer
,
new_name
)
tensor_list
=
[]
if
"w13"
in
name
:
size_n
,
size_k
=
n
*
2
,
k
else
:
size_n
,
size_k
=
k
,
n
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if
layer
.
weight_block_size
is
None
:
if
scales
.
nelement
()
==
e
:
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
scales
=
scales
.
view
(
e
,
1
,
1
).
repeat_interleave
(
size_n
,
2
)
elif
scales
.
nelement
()
>
e
and
scales
.
nelement
()
!=
e
*
size_n
:
assert
(
e
*
size_n
)
%
scales
.
nelement
()
==
0
s_size
=
scales
.
nelement
()
//
e
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
scales
=
scales
.
view
(
e
,
1
,
s_size
)
scales
=
scales
.
repeat_interleave
(
size_n
//
s_size
,
2
)
else
:
# channel-wise quantization
# (e, 1, size_n)
scales
=
scales
.
view
(
e
,
1
,
size_n
)
else
:
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
block_n
=
layer
.
weight_block_size
[
0
]
scales
=
scales
.
permute
(
0
,
2
,
1
).
repeat_interleave
(
block_n
,
2
)
# size_n may not divisible by block_size[0]
scales
=
scales
[...,
:
size_n
].
contiguous
()
for
i
in
range
(
e
):
marlin_scales
=
marlin_permute_scales
(
s
=
scales
[
i
],
size_k
=
size_k
,
size_n
=
size_n
,
group_size
=
group_size
)
tensor_list
.
append
(
marlin_scales
)
scales
=
torch
.
cat
([
x
.
unsqueeze
(
0
)
for
x
in
tensor_list
],
0
)
scales
=
torch
.
nn
.
Parameter
(
scales
,
requires_grad
=
False
)
setattr
(
layer
,
name
+
"_weight_scale"
,
scales
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
,
size_k_first
:
bool
=
True
)
->
torch
.
Tensor
:
"""
"""
Repack FP8 weights to gptq format (packed int32 elements)
Repack FP8 weights to gptq format (packed int32 elements)
"""
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
assert
fp8_tensor
.
ndim
==
2
fp8_tensor
=
fp8_tensor
.
T
if
size_k_first
else
fp8_tensor
fp8_tensor
=
fp8_tensor
.
contiguous
()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor
=
fp8_tensor
.
view
(
torch
.
int32
)
return
int32_tensor
.
T
.
contiguous
()
if
size_k_first
else
int32_tensor
# Reshape to prepare for packing
def
marlin_quant_fp8_torch
(
weight
,
group_size
):
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
size_n
,
size_k
=
weight
.
shape
device
=
weight
.
device
# Convert fp8 to uint8 (byte) representation
if
group_size
!=
-
1
:
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
scales
=
weight
.
view
(
size_n
,
-
1
,
group_size
).
abs
().
max
(
-
1
)[
0
]
/
448
repeated_scales
=
scales
.
repeat_interleave
(
group_size
,
1
)
fp8_weight
=
(
weight
/
repeated_scales
).
to
(
torch
.
float8_e4m3fn
)
weight_ref
=
fp8_weight
.
to
(
weight
.
dtype
)
*
repeated_scales
else
:
scales
=
weight
.
view
(
size_n
,
1
,
group_size
).
abs
().
max
(
-
1
)[
0
]
/
448
repeated_scales
=
scales
.
repeat_interleave
(
size_k
,
1
)
fp8_weight
=
(
weight
/
repeated_scales
).
to
(
torch
.
float8_e4m3fn
)
weight_ref
=
fp8_weight
.
to
(
weight
.
dtype
)
*
repeated_scales
packed_weight
=
pack_fp8_to_int32
(
fp8_weight
,
False
).
T
.
contiguous
()
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_weight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
size_k
=
size_k
,
size_n
=
size_n
,
num_bits
=
8
,
)
# Pack 4 uint8 values into one int32
marlin_scales
=
marlin_permute_scales
(
s
=
scales
.
T
,
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
size_k
=
size_k
,
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
size_n
=
size_n
,
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
group_size
=
group_size
)
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
return
weight_ref
.
T
,
marlin_qweight
,
marlin_scales
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
vllm/scalar_type.py
View file @
1d0c9d6b
...
@@ -6,6 +6,8 @@ from dataclasses import dataclass
...
@@ -6,6 +6,8 @@ from dataclasses import dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
_SCALAR_TYPES_ID_MAP
=
{}
# Mirrors enum in `core/scalar_type.hpp`
# Mirrors enum in `core/scalar_type.hpp`
class
NanRepr
(
Enum
):
class
NanRepr
(
Enum
):
...
@@ -158,6 +160,8 @@ class ScalarType:
...
@@ -158,6 +160,8 @@ class ScalarType:
assert
offset
<=
64
,
\
assert
offset
<=
64
,
\
f
"ScalarType fields too big
{
offset
}
to fit into an int64"
f
"ScalarType fields too big
{
offset
}
to fit into an int64"
_SCALAR_TYPES_ID_MAP
[
val
]
=
self
return
val
return
val
@
property
@
property
...
@@ -295,6 +299,13 @@ class ScalarType:
...
@@ -295,6 +299,13 @@ class ScalarType:
ret
.
id
# noqa B018: make sure the id is cached
ret
.
id
# noqa B018: make sure the id is cached
return
ret
return
ret
@
classmethod
def
from_id
(
cls
,
scalar_type_id
:
int
):
if
scalar_type_id
not
in
_SCALAR_TYPES_ID_MAP
:
raise
ValueError
(
f
"scalar_type_id
{
scalar_type_id
}
doesn't exists."
)
return
_SCALAR_TYPES_ID_MAP
[
scalar_type_id
]
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# for floating point types (leading f) the scheme is:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment