Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
debd6bbf
Unverified
Commit
debd6bbf
authored
Mar 11, 2025
by
Pavani Majety
Committed by
GitHub
Mar 12, 2025
Browse files
[Kernel] Add ModelOpt FP4 Checkpoint Support (#12520)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
5c538c37
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
388 additions
and
30 deletions
+388
-30
csrc/ops.h
csrc/ops.h
+5
-3
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
+6
-0
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+4
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
tests/models/decoder_only/language/test_nvfp4.py
tests/models/decoder_only/language/test_nvfp4.py
+82
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+17
-6
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+262
-16
No files found.
csrc/ops.h
View file @
debd6bbf
...
@@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
...
@@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
#ifndef USE_ROCM
bool
cutlass_scaled_mm_supports_fp4
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
);
torch
::
Tensor
const
&
alpha
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
View file @
debd6bbf
...
@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
...
@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
"be compiled using CUDA 12.8 and target "
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above."
);
"compute capability 100 or above."
);
}
}
bool
cutlass_scaled_mm_supports_fp4
(
int64_t
cuda_device_capability
)
{
int
runtimeVersion
;
cudaRuntimeGetVersion
(
&
runtimeVersion
);
return
cuda_device_capability
>=
100
&&
runtimeVersion
>=
12080
;
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
debd6bbf
...
@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
...
@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, "
:
must be contiguous")
#define CHECK_INPUT(x, st, m) \
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_CONTIGUOUS(x, m); \
...
...
csrc/torch_bindings.cpp
View file @
debd6bbf
...
@@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()"
);
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops
.
def
(
"cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp4"
,
&
cutlass_scaled_mm_supports_fp4
);
#endif
#endif
// Quantized GEMM for GPTQ.
// Quantized GEMM for GPTQ.
...
...
tests/models/decoder_only/language/test_nvfp4.py
0 → 100644
View file @
debd6bbf
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Model Optimizer nvfp4 models against ground truth generation
Note: these tests will only pass on B200
"""
import
os
from
typing
import
List
import
pytest
from
transformers
import
AutoTokenizer
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
# Set TOKENIZERS_PARALLELISM for the tokenizer library before any model loads.
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})

# Value passed as `max_model_len` when constructing the LLM under test.
MAX_MODEL_LEN = 1024

# Checkpoints this module is parametrized over.
MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"]

# Golden completions keyed by checkpoint name; entry i is compared against the
# generation for example prompt i.
EXPECTED_STRS_MAP = {
    "nvidia/Llama-3.3-70B-Instruct-FP4": [
        'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference',
        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
        'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process',
        'A neural network is a type of machine learning model inspired by the structure and function of the human brain',
        'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push',
        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading',
        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
        'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts'
    ]
}
# Exact-match comparison against golden strings: there is no baseline
# implementation to compare against, so the result is unstable w.r.t. the
# specifics of the fp4 implementation and the hardware being run on.
# The test is therefore disabled to keep it from breaking the build.
@pytest.mark.skip(
    reason=
    "Prevent unstable test based on golden strings from breaking the build "
    " and test input model being too large and hanging the system.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("nvfp4"),
                    reason="nvfp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    """Generate 20 greedy tokens per prompt and compare to golden strings.

    Args:
        example_prompts: Prompts to run; each is wrapped as a single user
            turn through the tokenizer's chat template.
        model_name: Checkpoint to load with ``quantization="nvfp4"``.
    """
    llm = LLM(
        model=model_name,
        max_model_len=MAX_MODEL_LEN,
        trust_remote_code=True,
        enforce_eager=True,
        quantization="nvfp4",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = []
    for prompt in example_prompts:
        conversation = [{"role": "user", "content": prompt}]
        formatted_prompts.append(
            tokenizer.apply_chat_template(conversation,
                                          tokenize=False,
                                          add_generation_prompt=True))

    params = SamplingParams(max_tokens=20, temperature=0)

    # Note: prompts are submitted one at a time due to numerical precision,
    # since the expected strings were generated this way.
    generations: List[str] = [
        llm.generate(single_prompt, params)[0].outputs[0].text
        for single_prompt in formatted_prompts
    ]
    del llm

    print(model_name, generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i, _ in enumerate(example_prompts):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
vllm/_custom_ops.py
View file @
debd6bbf
...
@@ -467,6 +467,10 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
...
@@ -467,6 +467,10 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
# cutlass
# cutlass
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
    """Forward the FP4 support query to the `_C` custom op of the same name.

    Args:
        cuda_device_capability: Device capability as an integer, passed to
            the op unchanged.

    Returns:
        Whatever the underlying op reports.
    """
    supports_fp4_op = torch.ops._C.cutlass_scaled_mm_supports_fp4
    return supports_fp4_op(cuda_device_capability)
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
block_scale_a
:
torch
.
Tensor
,
block_scale_a
:
torch
.
Tensor
,
block_scale_b
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
block_scale_b
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
...
...
vllm/config.py
View file @
debd6bbf
...
@@ -613,7 +613,7 @@ class ModelConfig:
...
@@ -613,7 +613,7 @@ class ModelConfig:
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"nvfp4"
]
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
vllm/model_executor/layers/linear.py
View file @
debd6bbf
...
@@ -30,12 +30,23 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -30,12 +30,23 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"AWQMarlinLinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"AWQLinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"GPTQMarlinLinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"Fp8LinearMethod"
,
"HQQMarlinMethod"
,
"QuarkLinearMethod"
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"HQQMarlinMethod"
,
"QuarkLinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
]
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
debd6bbf
...
@@ -14,6 +14,7 @@ QUANTIZATION_METHODS: List[str] = [
...
@@ -14,6 +14,7 @@ QUANTIZATION_METHODS: List[str] = [
"ptpc_fp8"
,
"ptpc_fp8"
,
"fbgemm_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
"modelopt"
,
"nvfp4"
,
# The order of gptq methods is important for config.py iteration over
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
# override_quantization_method(..)
"marlin"
,
"marlin"
,
...
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.hqq_marlin
import
HQQMarlinConfig
from
.hqq_marlin
import
HQQMarlinConfig
from
.ipex_quant
import
IPEXConfig
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.moe_wna16
import
MoeWNA16Config
from
.neuron_quant
import
NeuronQuantConfig
from
.neuron_quant
import
NeuronQuantConfig
from
.ptpc_fp8
import
PTPCFp8Config
from
.ptpc_fp8
import
PTPCFp8Config
...
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"nvfp4"
:
ModelOptNvFp4Config
,
# The order of gptq methods is important for config.py iteration over
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
debd6bbf
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
requantize_with_max_scale
)
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ACTIVATION_SCHEMES
=
[
"static"
]
QUANT_ALGOS
=
[
"FP8"
,
"NVFP4"
]
KV_CACHE_QUANT_ALGOS
=
[
"FP8"
]
class
ModelOptFp8Config
(
QuantizationConfig
):
class
ModelOptFp8Config
(
QuantizationConfig
):
...
@@ -54,12 +61,13 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -54,12 +61,13 @@ class ModelOptFp8Config(QuantizationConfig):
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
quant_method
=
quant_config
[
"quant_algo"
]
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
if
quant_method
not
in
QUANT_ALGOS
:
if
not
is_checkpoint_fp8_serialized
:
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
raise
ValueError
(
"ModelOpt currently only supports static FP8 "
" quantizations in vLLM. Please check the "
"quantization in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
return
cls
(
is_checkpoint_fp8_serialized
)
return
cls
(
is_checkpoint_fp8_serialized
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
...
@@ -72,15 +80,6 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -72,15 +80,6 @@ class ModelOptFp8Config(QuantizationConfig):
return
None
return
None
class
ModelOptFp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
super
().
__init__
(
quant_config
)
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
"""Linear method for Model Optimizer static quantization.
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
...
@@ -162,3 +161,250 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -162,3 +161,250 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
bias
=
bias
)
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Holds the settings read from the ``quantization`` section of
    ``hf_quant_config.json`` and selects the quantize method for each layer.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str,
        exclude_modules: List[str],
        group_size: int = 16,
    ) -> None:
        """Store the parsed settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint's
                quant_algo names NVFP4; a warning is logged when True.
            kv_cache_quant_algo: ``kv_cache_quant_algo`` value from the
                checkpoint config.
            exclude_modules: Module names handed to ``is_layer_skipped`` to
                decide which linear layers stay unquantized.
            group_size: Number of input elements sharing one block scale.
        """
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")

        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        return "modelopt_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from a parsed ``hf_quant_config.json``.

        Raises:
            ValueError: If ``quant_algo`` is not in ``QUANT_ALGOS``, or if any
                of ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules`` is absent or empty.
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        if quant_method not in QUANT_ALGOS:
            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                             " quantizations in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

        # Look the fields up with .get() so an absent key is reported through
        # the descriptive ValueError below rather than as a bare KeyError.
        required = {
            key: quant_config.get(key)
            for key in ("group_size", "kv_cache_quant_algo",
                        "exclude_modules")
        }
        missing = [key for key, value in required.items() if not value]
        if missing:
            raise ValueError("NVFP4 quantization requires group_size, "
                             "kv_cache_quant_algo and exclude_modules "
                             "specified in hf_quant_config.json; "
                             f"missing or empty: {missing}")

        return cls(is_checkpoint_nvfp4_serialized,
                   required["kv_cache_quant_algo"],
                   required["exclude_modules"], required["group_size"])

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers matched by ``exclude_modules`` stay unquantized, other
        linear layers use the NVFP4 method, attention layers use the FP8
        KV-cache method, and anything else gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None
def cutlass_fp4_supported() -> bool:
    """Tell whether the current platform can run the CUTLASS FP4 kernels.

    Returns:
        False on non-CUDA platforms; otherwise the answer of
        ``cutlass_scaled_mm_supports_fp4`` for the device capability, with
        ``-1`` passed when the capability is unknown.
    """
    if current_platform.is_cuda():
        device_capability = current_platform.get_device_capability()
        if device_capability is None:
            capability_int = -1
        else:
            capability_int = device_capability.to_int()
        return cutlass_scaled_mm_supports_fp4(capability_int)
    return False
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    KV-cache method that loads kv-cache scaling factors from FP8 checkpoints.

    Accepts either ModelOpt config flavour (FP8 or NVFP4) and delegates
    everything to ``BaseKVCacheMethod``.
    """

    def __init__(
        self,
        quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config],
    ):
        super().__init__(quant_config)
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        """Keep the config and refuse to run where FP4 is unsupported.

        Raises:
            ValueError: If ``cutlass_fp4_supported()`` reports False.
        """
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        if not self.cutlass_nvfp4_supported:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and above.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight and its three scale parameters.

        Registers ``weight`` (uint8, two fp4 values per byte),
        ``input_scale`` and ``weight_scale_2`` (float32, one entry per
        output partition) and ``weight_scale`` (one entry per
        ``group_size`` input elements) on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if (input_size_per_partition % 16 != 0):
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        # Dtype of the per-block weight scale created below (FP8-E4M3 for
        # serialized NVFP4 checkpoints). The packed weight itself is always
        # stored as uint8.
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=weight_dtype,
        ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

    def swizzle_blockscale(self, scale: torch.Tensor):
        """Pad and blockwise-interleave a block-scale tensor.

        Args:
            scale: FP8-E4M3 tensor of shape ``[M, K]`` or ``[B, M, K]``.

        Returns:
            A CUDA tensor holding the interleaved scales. Rows are padded
            with zeros to a multiple of 128 and columns to a multiple of 4,
            so the result has shape ``[M_padded, K_padded]`` for 2-D input
            and ``[B, M_padded, K_padded]`` for 3-D input.
        """
        assert (scale.dtype == torch.float8_e4m3fn)
        # Pad and blockwise interleave weight_scale
        scale_ndim = scale.ndim
        if scale.ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape
        round_up_multiple = lambda x, m: (x + m - 1) // m * m
        M_padded = round_up_multiple(M, 128)
        K_padded = round_up_multiple(K, 4)
        padded_scale = torch.zeros((B, M_padded, K_padded),
                                   dtype=scale.dtype,
                                   device=scale.device)
        padded_scale[:B, :M, :K] = scale
        batches, rows, cols = padded_scale.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                            cols // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        # Keep the padded extent: the buffer holds B * M_padded * K_padded
        # elements, so reshaping to the unpadded (M, K) only succeeds when no
        # padding was added and raises otherwise.
        return (swizzled_scale.reshape(M_padded, K_padded) if scale_ndim == 2
                else swizzled_scale.reshape(B, M_padded, K_padded))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the global scales and precompute kernel inputs.

        Replaces ``input_scale`` and ``weight_scale_2`` with their maxima,
        stores their product as ``alpha`` and stores the interleaved block
        scales as ``weight_scale_swizzled``.
        """
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert (layer.weight_scale.shape[1] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to FP4 and multiply it with the packed weight.

        Args:
            layer: Layer carrying ``weight``, ``input_scale``,
                ``weight_scale_swizzled`` and ``alpha``.
            x: 2-D activation tensor (its shape is unpacked into two dims).
            bias: Optional bias added to the matmul result.

        Returns:
            Tensor of shape ``[x.shape[0], layer.weight.shape[0]]``, computed
            with ``x.dtype`` passed as the output dtype.
        """
        output_dtype = x.dtype

        # for input only the contracting dimension has a constraint.
        x_m, _ = x.shape
        w_n, _ = layer.weight.shape
        output_shape = [x_m, w_n]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    output_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment