Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bda9d053
Unverified
Commit
bda9d053
authored
Jul 27, 2025
by
Wentao Ye
Committed by
GitHub
Jul 27, 2025
Browse files
[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
3d847a31
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
75 additions
and
106 deletions
+75
-106
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+3
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+8
-31
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+5
-61
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
...ecutor/layers/quantization/utils/nvfp4_emulation_utils.py
+3
-12
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+55
-0
No files found.
tests/quantization/test_compressed_tensors.py
View file @
bda9d053
...
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
,
cutlass_fp4_supported
)
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
)
sparse_cutlass_supported
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
bda9d053
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target
,
is_activation_quantization_format
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
should_ignore_layer
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.
nvfp4_emulation
_utils
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.utils.
quant
_utils
import
(
cutlass_fp4_supported
)
cutlass_fp4_supported
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
bda9d053
...
@@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
...
@@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin
)
prepare_moe_fp4_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_moe_fp8_layer_for_marlin
)
prepare_moe_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.
nvfp4_emulation
_utils
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.utils.
quant
_utils
import
(
cutlass_fp4_supported
)
cutlass_fp4_supported
,
swizzle_blockscale
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
def
swizzle_blockscale
(
self
,
scale
:
torch
.
tensor
):
assert
(
scale
.
dtype
==
torch
.
float8_e4m3fn
)
# Pad and blockwise interleave weight_scale
scale_ndim
=
scale
.
ndim
if
scale
.
ndim
==
2
:
scale
=
scale
.
unsqueeze
(
0
)
assert
scale
.
ndim
==
3
B
,
M
,
K
=
scale
.
shape
round_up_multiple
=
lambda
x
,
m
:
(
x
+
m
-
1
)
//
m
*
m
M_padded
=
round_up_multiple
(
M
,
128
)
K_padded
=
round_up_multiple
(
K
,
4
)
padded_scale
=
torch
.
zeros
((
B
,
M_padded
,
K_padded
),
dtype
=
scale
.
dtype
)
padded_scale
[:
B
,
:
M
,
:
K
]
=
scale
batches
,
rows
,
cols
=
padded_scale
.
shape
assert
rows
%
128
==
0
assert
cols
%
4
==
0
padded_scale
=
padded_scale
.
reshape
(
batches
,
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
swizzled_scale
=
padded_scale
.
permute
((
0
,
1
,
4
,
3
,
2
,
5
))
swizzled_scale
=
swizzled_scale
.
contiguous
().
cuda
()
return
(
swizzled_scale
.
reshape
(
M
,
K
)
if
scale_ndim
==
2
else
swizzled_scale
.
reshape
(
B
,
M
,
K
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# From packed to weight
# From packed to weight
...
@@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return
return
# swizzle weight scales
# swizzle weight scales
layer
.
w13_blockscale_swizzled
=
torch
.
nn
.
Parameter
(
layer
.
w13_blockscale_swizzled
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
self
.
swizzle_blockscale
(
layer
.
w13_weight_scale
),
layer
.
w13_weight_scale
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
w2_blockscale_swizzled
=
torch
.
nn
.
Parameter
(
layer
.
w2_blockscale_swizzled
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
self
.
swizzle_blockscale
(
layer
.
w2_weight_scale
),
layer
.
w2_weight_scale
),
requires_grad
=
False
)
requires_grad
=
False
)
# w13
# w13
w13_input_global_scale
=
layer
.
w13_input_global_scale
.
max
(
w13_input_global_scale
=
layer
.
w13_input_global_scale
.
max
(
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
bda9d053
...
@@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
...
@@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
)
from
vllm.distributed
import
get_ep_group
from
vllm.distributed
import
get_ep_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
...
@@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
...
@@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
prepare_fp4_layer_for_marlin
,
prepare_moe_fp4_layer_for_marlin
)
prepare_fp4_layer_for_marlin
,
prepare_moe_fp4_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
)
GroupShape
,
cutlass_fp4_supported
,
is_layer_skipped
,
swizzle_blockscale
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
requantize_with_max_scale
)
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
...
@@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
...
@@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
return
None
return
None
def
cutlass_fp4_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
cutlass_scaled_mm_supports_fp4
(
capability
)
class
ModelOptFp8KVCacheMethod
(
BaseKVCacheMethod
):
class
ModelOptFp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
Supports loading kv-cache scaling factors from FP8 checkpoints.
...
@@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
swizzle_blockscale
(
self
,
scale
:
torch
.
tensor
):
assert
(
scale
.
dtype
==
torch
.
float8_e4m3fn
)
# Pad and blockwise interleave weight_scale
scale_ndim
=
scale
.
ndim
if
scale
.
ndim
==
2
:
scale
=
scale
.
unsqueeze
(
0
)
assert
scale
.
ndim
==
3
B
,
M
,
K
=
scale
.
shape
round_up_multiple
=
lambda
x
,
m
:
(
x
+
m
-
1
)
//
m
*
m
M_padded
=
round_up_multiple
(
M
,
128
)
K_padded
=
round_up_multiple
(
K
,
4
)
padded_scale
=
torch
.
zeros
((
B
,
M_padded
,
K_padded
),
dtype
=
scale
.
dtype
)
padded_scale
[:
B
,
:
M
,
:
K
]
=
scale
batches
,
rows
,
cols
=
padded_scale
.
shape
assert
rows
%
128
==
0
assert
cols
%
4
==
0
padded_scale
=
padded_scale
.
reshape
(
batches
,
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
swizzled_scale
=
padded_scale
.
permute
((
0
,
1
,
4
,
3
,
2
,
5
))
swizzled_scale
=
swizzled_scale
.
contiguous
().
cuda
()
return
(
swizzled_scale
.
reshape
(
M
,
K
)
if
scale_ndim
==
2
else
swizzled_scale
.
reshape
(
B
,
M
,
K
))
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# global scales:
# global scales:
...
@@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
"Expected weight_scale.dim(1) to be divisible by 16"
)
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert
(
layer
.
weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
assert
(
layer
.
weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
"Weight Block scale must be represented as FP8-E4M3"
)
"Weight Block scale must be represented as FP8-E4M3"
)
swizzled_weight_scale
=
self
.
swizzle_blockscale
(
layer
.
weight_scale
)
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
layer
.
weight_scale_swizzled
=
Parameter
(
swizzled_weight_scale
,
layer
.
weight_scale_swizzled
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
swizzle_blockscale
(
self
,
scale
:
torch
.
tensor
):
assert
(
scale
.
dtype
==
torch
.
float8_e4m3fn
)
# Pad and blockwise interleave weight_scale
scale_ndim
=
scale
.
ndim
if
scale
.
ndim
==
2
:
scale
=
scale
.
unsqueeze
(
0
)
assert
scale
.
ndim
==
3
B
,
M
,
K
=
scale
.
shape
round_up_multiple
=
lambda
x
,
m
:
(
x
+
m
-
1
)
//
m
*
m
M_padded
=
round_up_multiple
(
M
,
128
)
K_padded
=
round_up_multiple
(
K
,
4
)
padded_scale
=
torch
.
zeros
((
B
,
M_padded
,
K_padded
),
dtype
=
scale
.
dtype
)
padded_scale
[:
B
,
:
M
,
:
K
]
=
scale
batches
,
rows
,
cols
=
padded_scale
.
shape
assert
rows
%
128
==
0
assert
cols
%
4
==
0
padded_scale
=
padded_scale
.
reshape
(
batches
,
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
swizzled_scale
=
padded_scale
.
permute
((
0
,
1
,
4
,
3
,
2
,
5
))
swizzled_scale
=
swizzled_scale
.
contiguous
().
cuda
()
return
(
swizzled_scale
.
reshape
(
M
,
K
)
if
scale_ndim
==
2
else
swizzled_scale
.
reshape
(
B
,
M
,
K
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# GEMM 1
# GEMM 1
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
...
@@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Expected weight_scale.dim(1) to be divisible by 16"
)
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert
(
layer
.
w13_weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
assert
(
layer
.
w13_weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
"Weight Blockscale must be represented as FP8-E4M3"
)
"Weight Blockscale must be represented as FP8-E4M3"
)
w13_blockscale_swizzled
=
self
.
swizzle_blockscale
(
w13_blockscale_swizzled
=
swizzle_blockscale
(
layer
.
w13_weight_scale
)
layer
.
w13_weight_scale
)
layer
.
w13_blockscale_swizzled
=
Parameter
(
w13_blockscale_swizzled
,
layer
.
w13_blockscale_swizzled
=
Parameter
(
w13_blockscale_swizzled
,
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Expected weight_scale.dim(1) to be divisible by 16"
)
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert
(
layer
.
w2_weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
assert
(
layer
.
w2_weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
"Weight Blockscale must be represented as FP8-E4M3"
)
"Weight Blockscale must be represented as FP8-E4M3"
)
w2_blockscale_swizzled
=
self
.
swizzle_blockscale
(
layer
.
w2_weight_scale
)
w2_blockscale_swizzled
=
swizzle_blockscale
(
layer
.
w2_weight_scale
)
layer
.
w2_blockscale_swizzled
=
Parameter
(
w2_blockscale_swizzled
,
layer
.
w2_blockscale_swizzled
=
Parameter
(
w2_blockscale_swizzled
,
requires_grad
=
False
)
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
View file @
bda9d053
...
@@ -2,13 +2,12 @@
...
@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch
from
vllm._custom_ops
import
cutlass_scaled_mm_supports_fp4
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
__all__
=
[
__all__
=
[
"break_fp4_bytes"
,
"dequantize_to_dtype"
,
"ref_nvfp4_quant"
,
"break_fp4_bytes"
,
"cutlass_fp4_supported"
"dequantize_to_dtype"
,
"ref_nvfp4_quant"
,
]
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
...
@@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
...
@@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
def
cutlass_fp4_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
cutlass_scaled_mm_supports_fp4
(
capability
)
def
break_fp4_bytes
(
a
,
dtype
):
def
break_fp4_bytes
(
a
,
dtype
):
assert
a
.
dtype
==
torch
.
uint8
assert
a
.
dtype
==
torch
.
uint8
m
,
n
=
a
.
shape
m
,
n
=
a
.
shape
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
bda9d053
...
@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
...
@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
import
numpy
import
numpy
import
torch
import
torch
from
vllm._custom_ops
import
cutlass_scaled_mm_supports_fp4
from
vllm.model_executor.layers.quantization.qqq
import
(
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
@@ -592,3 +594,56 @@ def awq_pack(
...
@@ -592,3 +594,56 @@ def awq_pack(
q_w
=
q_w
.
reshape
((
-
1
,
size_n
)).
contiguous
()
q_w
=
q_w
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
pack_cols
(
q_w
,
num_bits
,
size_k
,
size_n
)
return
pack_cols
(
q_w
,
num_bits
,
size_k
,
size_n
)
def
swizzle_blockscale
(
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert
scale
.
dtype
==
torch
.
float8_e4m3fn
,
(
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format."
)
scale_ndim
=
scale
.
ndim
if
scale_ndim
==
2
:
scale
=
scale
.
unsqueeze
(
0
)
# (1, M, K)
assert
scale
.
ndim
==
3
,
"Expected a 2-D or 3-D tensor for block scales."
B
,
M
,
K
=
scale
.
shape
def
_round_up
(
x
:
int
,
m
:
int
)
->
int
:
return
(
x
+
m
-
1
)
//
m
*
m
M_padded
=
_round_up
(
M
,
128
)
K_padded
=
_round_up
(
K
,
4
)
padded
=
torch
.
zeros
((
B
,
M_padded
,
K_padded
),
dtype
=
scale
.
dtype
,
device
=
scale
.
device
)
padded
[:
B
,
:
M
,
:
K
]
=
scale
# Reshape / permute to the layout required by the kernel.
padded
=
padded
.
reshape
(
B
,
M_padded
//
128
,
4
,
32
,
K_padded
//
4
,
4
)
swizzled
=
padded
.
permute
(
0
,
1
,
4
,
3
,
2
,
5
).
contiguous
().
cuda
()
if
scale_ndim
==
2
:
return
swizzled
.
reshape
(
M
,
K
)
return
swizzled
.
reshape
(
B
,
M
,
K
)
def
cutlass_fp4_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
cutlass_scaled_mm_supports_fp4
(
capability
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment