Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
debd6bbf
Unverified
Commit
debd6bbf
authored
Mar 11, 2025
by
Pavani Majety
Committed by
GitHub
Mar 12, 2025
Browse files
[Kernel] Add ModelOpt FP4 Checkpoint Support (#12520)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
5c538c37
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
388 additions
and
30 deletions
+388
-30
csrc/ops.h
csrc/ops.h
+5
-3
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
+6
-0
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+4
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
tests/models/decoder_only/language/test_nvfp4.py
tests/models/decoder_only/language/test_nvfp4.py
+82
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+17
-6
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+262
-16
No files found.
csrc/ops.h
View file @
debd6bbf
...
...
@@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
bool
cutlass_scaled_mm_supports_fp4
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
View file @
debd6bbf
...
...
@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above."
);
}
// Reports whether the CUTLASS NVFP4 scaled-mm path is usable.
//
// Returns true only when `cuda_device_capability` is at least 100 and the
// CUDA runtime reports a version of at least 12080 (CUDA 12.8, the same
// toolkit level named in the build-requirement error message above).
// Returns false when the runtime version cannot be queried.
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
  // Zero-initialised so that a failed query can never leave an indeterminate
  // value behind to be compared against the minimum version.
  int runtimeVersion = 0;
  // A non-zero status is anything other than cudaSuccess; treat it as
  // "FP4 not supported" instead of trusting the output parameter.
  if (cudaRuntimeGetVersion(&runtimeVersion) != 0) {
    return false;
  }
  return cuda_device_capability >= 100 && runtimeVersion >= 12080;
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
debd6bbf
...
...
@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, "
:
must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
...
...
csrc/torch_bindings.cpp
View file @
debd6bbf
...
...
@@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops
.
def
(
"cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp4"
,
&
cutlass_scaled_mm_supports_fp4
);
#endif
// Quantized GEMM for GPTQ.
...
...
tests/models/decoder_only/language/test_nvfp4.py
0 → 100644
View file @
debd6bbf
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Model Optimizer nvfp4 models against ground truth generation
Note: these tests will only pass on B200
"""
import
os
from
typing
import
List
import
pytest
from
transformers
import
AutoTokenizer
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"true"
MAX_MODEL_LEN
=
1024
MODELS
=
[
"nvidia/Llama-3.3-70B-Instruct-FP4"
]
EXPECTED_STRS_MAP
=
{
"nvidia/Llama-3.3-70B-Instruct-FP4"
:
[
'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference'
,
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to '
,
'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process'
,
'A neural network is a type of machine learning model inspired by the structure and function of the human brain'
,
'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading'
,
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of'
,
'Here are the translations:
\n\n
* Japanese: (Sasuga no tori ga miwa o ts'
]
}
# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp4 implementation or
# the hardware being run on.
# Disabled to prevent it from breaking the build
@pytest.mark.skip(
    reason=
    "Prevent unstable test based on golden strings from breaking the build "
    " and test input model being too large and hanging the system.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("nvfp4"),
                    reason="nvfp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    """Generate greedily with an NVFP4 checkpoint and compare to golden text.

    Each example prompt is wrapped in the model's chat template, generated
    on its own with ``max_tokens=20`` and ``temperature=0``, and the output
    is required to equal the matching entry of ``EXPECTED_STRS_MAP``.
    """
    llm = LLM(
        model=model_name,
        max_model_len=MAX_MODEL_LEN,
        trust_remote_code=True,
        enforce_eager=True,
        quantization="nvfp4",
    )
    chat_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Wrap every raw prompt as a single-turn user conversation.
    chat_prompts = []
    for raw_prompt in example_prompts:
        conversation = [{"role": "user", "content": raw_prompt}]
        chat_prompts.append(
            chat_tokenizer.apply_chat_template(conversation,
                                               tokenize=False,
                                               add_generation_prompt=True))

    sampling = SamplingParams(max_tokens=20, temperature=0)

    generations: List[str] = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for chat_prompt in chat_prompts:
        request_outputs = llm.generate(chat_prompt, sampling)
        first_completion = request_outputs[0].outputs[0]
        generations.append(first_completion.text)
    del llm

    print(model_name, generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        failure_message = (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
        assert expected_str == generated_str, failure_message
vllm/_custom_ops.py
View file @
debd6bbf
...
...
@@ -467,6 +467,10 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
# cutlass
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
    """Ask the ``_C`` extension whether the CUTLASS FP4 scaled-mm is usable.

    Args:
        cuda_device_capability: Integer device capability, passed through
            unchanged to the op of the same name in ``torch.ops._C``.

    Returns:
        Whatever the underlying op reports for that capability.
    """
    supports_fp4_op = torch.ops._C.cutlass_scaled_mm_supports_fp4
    return supports_fp4_op(cuda_device_capability)
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
block_scale_a
:
torch
.
Tensor
,
block_scale_b
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
...
...
vllm/config.py
View file @
debd6bbf
...
...
@@ -613,7 +613,7 @@ class ModelConfig:
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"nvfp4"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
vllm/model_executor/layers/linear.py
View file @
debd6bbf
...
...
@@ -30,12 +30,23 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"HQQMarlinMethod"
,
"QuarkLinearMethod"
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"HQQMarlinMethod"
,
"QuarkLinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
debd6bbf
...
...
@@ -14,6 +14,7 @@ QUANTIZATION_METHODS: List[str] = [
"ptpc_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
"nvfp4"
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
,
...
...
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.hqq_marlin
import
HQQMarlinConfig
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.neuron_quant
import
NeuronQuantConfig
from
.ptpc_fp8
import
PTPCFp8Config
...
...
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"nvfp4"
:
ModelOptNvFp4Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
debd6bbf
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
ACTIVATION_SCHEMES
=
[
"static"
]
QUANT_ALGOS
=
[
"FP8"
,
"NVFP4"
]
KV_CACHE_QUANT_ALGOS
=
[
"FP8"
]
class
ModelOptFp8Config
(
QuantizationConfig
):
...
...
@@ -54,12 +61,13 @@ class ModelOptFp8Config(QuantizationConfig):
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
if
not
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"ModelOpt currently only supports static FP8 "
"quantization in vLLM. Please check the "
if
quant_method
not
in
QUANT_ALGOS
:
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
" quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
return
cls
(
is_checkpoint_fp8_serialized
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -72,15 +80,6 @@ class ModelOptFp8Config(QuantizationConfig):
return
None
class
ModelOptFp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
super
().
__init__
(
quant_config
)
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
...
...
@@ -162,3 +161,250 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Holds the settings read from a ModelOpt ``hf_quant_config.json``:
    whether the checkpoint is NVFP4-serialized, the kv-cache quantization
    algorithm, the list of modules excluded from quantization, and the
    block-scale group size.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str,
        exclude_modules: List[str],
        group_size: int = 16,
    ) -> None:
        """Store the parsed settings; warns when an NVFP4 checkpoint is used.

        Args:
            is_checkpoint_nvfp4_serialized: True when the checkpoint's
                quant algo contains "NVFP4".
            kv_cache_quant_algo: The ``kv_cache_quant_algo`` config value.
            exclude_modules: Module names checked with ``is_layer_skipped``
                in ``get_quant_method``; may be empty.
            group_size: The ``group_size`` config value (default 16).
        """
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization config."""
        return "modelopt_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (100)."""
        return 100

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from a parsed ``hf_quant_config.json``.

        Args:
            config: Parsed config; its "quantization" section must provide
                ``quant_algo``, ``kv_cache_quant_algo``, ``group_size`` and
                ``exclude_modules``.

        Returns:
            A ``ModelOptNvFp4Config`` populated from that section.

        Raises:
            ValueError: If ``quant_algo`` is not in ``QUANT_ALGOS``, if
                ``group_size`` or ``kv_cache_quant_algo`` is empty, or if
                ``exclude_modules`` is ``None``.
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        if quant_method not in QUANT_ALGOS:
            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                             " quantizations in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        group_size = quant_config["group_size"]
        exclude_modules = quant_config["exclude_modules"]
        # An empty exclusion list is a legitimate value (nothing is
        # excluded), so only a missing (None) list is rejected; truth-testing
        # the list would refuse such checkpoints.
        if (not (group_size and kv_cache_quant_algo)
                or exclude_modules is None):
            raise ValueError("NVFP4 quantization requires group size, "
                             "kv_cache_quant_algo and exclude_modules "
                             "specified in hf_quant_config.json")
        return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                   exclude_modules, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Linear layers get ``UnquantizedLinearMethod`` when ``prefix`` is
        skipped according to ``exclude_modules``, otherwise
        ``ModelOptNvFp4LinearMethod``. Attention layers get
        ``ModelOptFp8KVCacheMethod``. Any other layer yields ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None
def cutlass_fp4_supported() -> bool:
    """Tell whether the current platform can run the CUTLASS FP4 kernels.

    Returns False immediately on non-CUDA platforms. Otherwise the device
    capability is converted to an int (``-1`` when the platform reports no
    capability) and handed to ``cutlass_scaled_mm_supports_fp4``.
    """
    if not current_platform.is_cuda():
        return False

    device_capability = current_platform.get_device_capability()
    if device_capability is None:
        capability_int = -1
    else:
        capability_int = device_capability.to_int()
    return cutlass_scaled_mm_supports_fp4(capability_int)
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    The constructor accepts either a ``ModelOptFp8Config`` or a
    ``ModelOptNvFp4Config`` and adds no behavior of its own beyond
    ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: Union[ModelOptFp8Config,
                                           ModelOptNvFp4Config]):
        # The config is forwarded unchanged to the base class.
        super().__init__(quant_config)
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,

    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        """Store the config and fail fast when FP4 kernels are unavailable.

        Raises:
            ValueError: If ``cutlass_fp4_supported()`` reports False.
        """
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        if not self.cutlass_nvfp4_supported:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and above.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 parameters on ``layer``.

        Registers ``weight`` (uint8, two FP4 values packed per byte along the
        input dimension), ``input_scale`` and ``weight_scale_2`` (float32,
        one entry per output partition) and ``weight_scale`` (one value per
        ``group_size`` input elements).

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if (input_size_per_partition % 16 != 0):
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")
        # dtype used for the per-block weight scale below: FP8-E4M3 for a
        # serialized NVFP4 checkpoint. The packed weight itself is uint8.
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=weight_dtype,
        ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)
        layer.register_parameter("weight_scale", weight_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        """Pad and blockwise-interleave a block-scale tensor.

        Args:
            scale: FP8-E4M3 tensor of shape ``[M, K]`` or ``[B, M, K]``.

        Returns:
            The interleaved scales moved to CUDA. Rows are zero-padded to a
            multiple of 128 and columns to a multiple of 4, so the result
            has shape ``[M_padded, K_padded]`` for 2-D input and
            ``[B, M_padded, K_padded]`` for 3-D input. When no padding is
            needed these equal the input dimensions.
        """
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        # The swizzled tensor holds B * M_padded * K_padded elements, so it
        # must be reshaped with the padded sizes; using the unpadded M and K
        # raises whenever padding was actually added.
        return (swizzled_scale.reshape(M_padded, K_padded)
                if scale_ndim == 2 else swizzled_scale.reshape(
                    B, M_padded, K_padded))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the global scales and pre-swizzle the block scales.

        Replaces ``input_scale`` and ``weight_scale_2`` with their maxima
        (as float32), stores their product as ``alpha`` and stores the
        swizzled ``weight_scale`` as ``weight_scale_swizzled``.
        """
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert (layer.weight_scale.shape[1] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to FP4 and multiply it with the layer's weight.

        Args:
            layer: Module carrying ``weight``, ``input_scale``,
                ``weight_scale_swizzled`` and ``alpha``.
            x: Two-dimensional activation tensor (its shape is unpacked
                into exactly two values).
            bias: Optional bias added to the matmul output.

        Returns:
            The result viewed as ``[x.shape[0], layer.weight.shape[0]]``.
        """
        output_dtype = x.dtype

        # for input only the contracting dimension has a constraint.
        x_m, _ = x.shape
        w_n, _ = layer.weight.shape
        output_shape = [x_m, w_n]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    output_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment