Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fc5c23b
Unverified
Commit
4fc5c23b
authored
Feb 12, 2025
by
Kaixi Hou
Committed by
GitHub
Feb 12, 2025
Browse files
[NVIDIA] Support nvfp4 quantization (#12784)
parent
9f9704dc
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
688 additions
and
13 deletions
+688
-13
CMakeLists.txt
CMakeLists.txt
+18
-0
cmake/utils.cmake
cmake/utils.cmake
+13
-5
csrc/cuda_utils.h
csrc/cuda_utils.h
+12
-0
csrc/cuda_utils_kernels.cu
csrc/cuda_utils_kernels.cu
+14
-8
csrc/ops.h
csrc/ops.h
+4
-0
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+32
-0
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+379
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-0
tests/kernels/test_nvfp4_quant.py
tests/kernels/test_nvfp4_quant.py
+149
-0
tests/test_scalartype.py
tests/test_scalartype.py
+1
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+57
-0
vllm/scalar_type.py
vllm/scalar_type.py
+3
-0
No files found.
CMakeLists.txt
View file @
4fc5c23b
...
...
@@ -264,6 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
"csrc/cutlass_extensions/common.cpp"
)
...
...
@@ -377,6 +378,23 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif
()
endif
()
# FP4 Archs and flags
#
# The NVFP4 quantization kernels are compiled only for the "10.0a" target and
# only when the CUDA compiler is version 12.8 or newer. When they are built,
# ENABLE_NVFP4=1 is added to the GPU flags so the entry point in
# nvfp4_quant_entry.cu dispatches to them; otherwise that entry point reports
# that no nvfp4 quantization was compiled.
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
  set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu")
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${FP4_ARCHS}")
  list(APPEND VLLM_EXT_SRC "${SRCS}")
  list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
  message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
  if(FP4_ARCHS)
    # Compatible archs were requested, so the compiler version is the blocker.
    message(STATUS "Not building NVFP4 as the CUDA compiler version "
                   "(${CMAKE_CUDA_COMPILER_VERSION}) is older than 12.8.")
  else()
    message(STATUS "Not building NVFP4 as no compatible archs were found.")
  endif()
  # clear FP4_ARCHS
  set(FP4_ARCHS)
endif()
#
# Machete kernels
...
...
cmake/utils.cmake
View file @
4fc5c23b
...
...
@@ -257,9 +257,9 @@ endmacro()
# where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for
9
.0a, if
9
.0a is in `SRC_CUDA_ARCHS` and
9
.0 is
# in `TGT_CUDA_ARCHS` then we should remove
9
.0a from `SRC_CUDA_ARCHS` and add
#
9
.0a to the result (and remove
9
.0 from TGT_CUDA_ARCHS).
# We have special handling for
x
.0a, if
x
.0a is in `SRC_CUDA_ARCHS` and
x
.0 is
# in `TGT_CUDA_ARCHS` then we should remove
x
.0a from `SRC_CUDA_ARCHS` and add
#
x
.0a to the result (and remove
x
.0 from TGT_CUDA_ARCHS).
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
...
...
@@ -272,8 +272,8 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
list
(
REMOVE_DUPLICATES SRC_CUDA_ARCHS
)
set
(
TGT_CUDA_ARCHS_
${
TGT_CUDA_ARCHS
}
)
# if
9
.0a is in SRC_CUDA_ARCHS and
9
.0 is in CUDA_ARCHS then we should
# remove
9
.0a from SRC_CUDA_ARCHS and add
9
.0a to _CUDA_ARCHS
# if
x
.0a is in SRC_CUDA_ARCHS and
x
.0 is in CUDA_ARCHS then we should
# remove
x
.0a from SRC_CUDA_ARCHS and add
x
.0a to _CUDA_ARCHS
set
(
_CUDA_ARCHS
)
if
(
"9.0a"
IN_LIST SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"9.0a"
)
...
...
@@ -283,6 +283,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endif
()
endif
()
# Special handling for 10.0a: it is always taken out of SRC_CUDA_ARCHS, and if
# 10.0 is among the target archs it replaces 10.0 in the result.
if("10.0a" IN_LIST SRC_CUDA_ARCHS)
  list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
  # Test the working copy, which is also the list the item is removed from.
  if("10.0" IN_LIST TGT_CUDA_ARCHS_)
    list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
    # Append instead of set() so that anything already collected in
    # _CUDA_ARCHS is preserved.
    list(APPEND _CUDA_ARCHS "10.0a")
  endif()
endif()
list
(
SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
...
...
csrc/cuda_utils.h
View file @
4fc5c23b
#pragma once
#include <stdio.h>
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
...
...
@@ -10,6 +12,16 @@
#define HOST_INLINE inline
#endif
// Evaluates the CUDA runtime call `cmd` exactly once. If the result is not
// cudaSuccess, prints the source location and the CUDA error string to stderr
// and terminates the process with EXIT_FAILURE.
#define CUDA_CHECK(cmd)                                            \
  do {                                                             \
    cudaError_t e = (cmd);                                         \
    if (e != cudaSuccess) {                                        \
      fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n", __FILE__, \
              __LINE__, cudaGetErrorString(e));                    \
      exit(EXIT_FAILURE);                                          \
    }                                                              \
  } while (0)
int64_t
get_device_attribute
(
int64_t
attribute
,
int64_t
device_id
);
int64_t
get_max_shared_memory_per_block_device_attribute
(
int64_t
device_id
);
csrc/cuda_utils_kernels.cu
View file @
4fc5c23b
#include "cuda_utils.h"
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
// Returns the value of a CUDA device attribute.
//
//   attribute  a cudaDeviceAttr enumerator passed as an integer.
//   device_id  the device to query; a negative value selects the device that
//              cudaGetDevice reports as current.
//
// Any CUDA error aborts the process through CUDA_CHECK.
//
// The attribute is queried on every call. Caching the result in a
// function-local static would initialise it from the first call's arguments
// only and hand that same value back for every other (attribute, device_id)
// combination.
int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
  int device = static_cast<int>(device_id);
  if (device < 0) {
    CUDA_CHECK(cudaGetDevice(&device));
  }
  int value = 0;
  CUDA_CHECK(cudaDeviceGetAttribute(
      &value, static_cast<cudaDeviceAttr>(attribute), device));
  return value;
}
...
...
csrc/ops.h
View file @
4fc5c23b
...
...
@@ -195,6 +195,10 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
...
...
csrc/quantization/fp4/nvfp4_quant_entry.cu
0 → 100644
View file @
4fc5c23b
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
// The sm100a implementation is declared only when the build defines
// ENABLE_NVFP4 to a non-zero value.
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf);
#endif

// Entry point for NVFP4 quantization.
//
// Forwards all four tensors unchanged to scaled_fp4_quant_sm100a when the
// NVFP4 kernels were compiled in; otherwise raises a "not implemented" error.
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_sf,
                      torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
  return;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
#endif
}
csrc/quantization/fp4/nvfp4_quant_kernels.cu
0 → 100644
View file @
4fc5c23b
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "cuda_utils.h"
// Maps a 16-bit floating point type to its two-element vector type and back:
//   half          <-> half2
//   __nv_bfloat16 <-> __nv_bfloat162
// Any type without a specialisation below resolves to half2.
template <typename T>
struct TypeConverter {
  using Type = half2;
};

// Scalar -> vector direction.
template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

// Vector -> scalar direction, kept for generality.
template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};
// Number of input elements each thread converts; eight e2m1 results are
// packed into one uint32_t.
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
// Number of consecutive elements that share one scale factor.
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Name used by the host-side launcher. Defined as a typed alias of the kernel
// constant so the two values cannot drift apart.
constexpr int ELTS_PER_THREAD = CVT_FP4_ELTS_PER_THREAD;
// Packs eight float32 values into eight e2m1 values returned as one uint32_t.
//
// Each cvt instruction turns one pair of floats into a single byte (note that
// the two source operands of every pair are passed in swapped order), and the
// four bytes are then assembled into the 32-bit result. When not compiling
// for __CUDA_ARCH__ >= 1000 the function is a stub that returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return packed;
#else
  return 0;
#endif
}
// Packs four float2 values (eight floats) into eight e2m1 values returned as
// one uint32_t.
//
// Every float2 supplies the two inputs of one cvt instruction (.x and .y are
// passed in swapped operand order) and yields one byte; the four bytes are
// then assembled into the 32-bit result. When not compiling for
// __CUDA_ARCH__ >= 1000 the function is a stub that returns 0.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(packed)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return packed;
#else
  return 0;
#endif
}
// Fast reciprocal: returns the result of the rcp.approx.ftz.f32 instruction
// applied to `a`.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float result;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(result) : "f"(a));
  return result;
}
// Computes where the scale factor (SF) for element group (rowIdx, colIdx)
// lives inside the swizzled SF buffer `SFout`.
//
// Only threads whose threadIdx.x is a multiple of CVT_FP4_NUM_THREADS_PER_SF
// get an address; every other thread, and any build not targeting
// __CUDA_ARCH__ >= 1000, gets nullptr.
//
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
//   --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
//
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads writes one SF to global memory, so only the first
  // thread of each group receives an address.
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t const sfCol = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t const sfRow = rowIdx;

  // Number of K tiles per row; each K tile spans 4 SF vectors of size 16.
  int const colsPerKTile = CVT_FP4_SF_VEC_SIZE * 4;
  int32_t const kTileCount = (numCols + colsPerKTile - 1) / colsPerKTile;

  // Strides of the five layout dimensions, outermost first.
  int64_t const mTileStride = kTileCount * 32 * 4 * 4;
  int64_t const kTileStride = 32 * 4 * 4;
  int64_t const outerMStride = 4 * 4;
  int64_t const innerMStride = 4;
  int64_t const innerKStride = 1;

  // Indices along the same five dimensions. The M tile layout [32, 4] is
  // column-major.
  int32_t const mTile = sfRow / (32 * 4);
  int32_t const kTile = sfCol / 4;
  int32_t const outerM = sfRow % 32;
  int32_t const innerM = (sfRow % (32 * 4)) / 32;
  int32_t const innerK = sfCol % 4;

  // Global byte offset into the SF buffer.
  int64_t const byteOffset = mTile * mTileStride + kTile * kTileStride +
                             outerM * outerMStride + innerM * innerMStride +
                             innerK * innerKStride;

  return reinterpret_cast<uint8_t*>(SFout) + byteOffset;
#else
  return nullptr;
#endif
}
// A 16-byte packed vector.
//
// The generic form stores four two-element vectors of the type selected by
// TypeConverter<Type>; the fp8 specialisation stores eight __nv_fp8x2_e4m3.
template <class Type>
struct PackedVec {
  using Pair = typename TypeConverter<Type>::Type;
  Pair elts[4];
};

template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Quantizes the eight values held in `vec` to e2m1 and returns them packed in
// one uint32_t.
//
//   vec         the calling thread's eight input values.
//   SFScaleVal  global scale applied to the per-vector scale factor (SF).
//   SFout       where to store the 8-bit SF; may be null, in which case
//               nothing is stored.
//
// The absolute maximum is combined with the neighbouring lane (lane XOR 1),
// so the SF covers the 16 values held by a pair of threads. With UE8M0_SF
// the SF keeps only the float32 exponent bits; otherwise it is rounded to
// e4m3. When not compiling for __CUDA_ARCH__ >= 1000 the function returns 0.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
                                         float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr int kNumPairs = CVT_FP4_ELTS_PER_THREAD / 2;

  // Absolute maximum over this thread's eight values, kept as a pair.
  auto pairMax = __habs2(vec.elts[0]);
  #pragma unroll
  for (int p = 1; p < kNumPairs; p++) {
    pairMax = __hmax2(pairMax, __habs2(vec.elts[p]));
  }

  // Fold in the partner thread's maximum to cover all 16 values.
  pairMax = __hmax2(__shfl_xor_sync(uint32_t(-1), pairMax, 1), pairMax);
  float const absMax = float(__hmax(pairMax.x, pairMax.y));

  // SF = global scale * (vector max / 6.0), 6.0 being the e2m1 maximum.
  // TODO: use half as compute data type.
  float sf = SFScaleVal * (absMax * reciprocal_approximate_ftz(6.0f));

  // 8-bit encoding of the SF; `sf` is replaced by the value that encoding
  // represents so that the data is scaled by exactly what gets stored.
  uint8_t sfBits;
  if constexpr (UE8M0_SF) {
    // float32 = 1 sign bit + 8 exponent bits + 23 mantissa bits: drop the
    // mantissa, keep the low 8 bits as the encoding, and convert back.
    uint32_t upperBits = reinterpret_cast<uint32_t&>(sf) >> 23;
    sfBits = upperBits & 0xff;
    reinterpret_cast<uint32_t&>(sf) = upperBits << 23;
  } else {
    // Here the SF is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 rounded = __nv_fp8_e4m3(sf);
    reinterpret_cast<__nv_fp8_e4m3&>(sfBits) = rounded;
    sf = float(rounded);
  }

  // Multiplier applied to the inputs:
  //   reciprocal(sf * reciprocal(SFScaleVal)), or 0 when sf is 0.
  float const multiplier =
      sf != 0 ? reciprocal_approximate_ftz(
                    sf * reciprocal_approximate_ftz(SFScaleVal))
              : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = sfBits;
  }

  // Widen the inputs to float and apply the multiplier.
  float2 scaled[kNumPairs];
  #pragma unroll
  for (int p = 0; p < kNumPairs; p++) {
    if constexpr (std::is_same_v<Type, half>) {
      scaled[p] = __half22float2(vec.elts[p]);
    } else {
      scaled[p] = __bfloat1622float2(vec.elts[p]);
    }
    scaled[p].x *= multiplier;
    scaled[p].y *= multiplier;
  }

  // Convert to e2m1 and hand the packed word back to the caller.
  return fp32_vec_to_e2m1(scaled);
#else
  return 0;
#endif
}
// Use UE4M3 by default.
//
// Kernel that quantizes a numRows x numCols matrix of `Type` to packed e2m1.
//
//   in       input matrix, read as rows of numCols / 8 PackedVec entries.
//   SFScale  pointer to the global scale; nullptr means 1.0.
//   out      one uint32_t (eight e2m1 values) per eight input elements.
//   SFout    scale factor buffer, addressed through
//            cvt_quant_to_fp4_get_sf_out_offset.
//
// Blocks step over rows in increments of gridDim.x and threads step over the
// packed columns of a row in increments of blockDim.x. When not compiling for
// __CUDA_ARCH__ >= 1000 the kernel body is empty.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Number of PackedVec entries (and of output words) per row.
  int const packedCols = numCols / CVT_FP4_ELTS_PER_THREAD;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < packedCols;
         colIdx += blockDim.x) {
      // Widen before multiplying: the product of two ints would otherwise be
      // computed in 32 bits and could overflow before reaching int64_t.
      int64_t inOffset = static_cast<int64_t>(rowIdx) * packedCols + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
// Launches the FP4 quantization kernel for an m x n input on `stream`.
//
//   m, n                 rows and columns of `input`.
//   input                device pointer to the m x n source values.
//   SFScale              device pointer to the global scale (the kernel
//                        treats nullptr as 1.0).
//   output               device pointer receiving the packed e2m1 values.
//   SFOuput              device pointer receiving the scale factors.
//   useUE8M0             selects the UE8M0 scale factor encoding.
//   multiProcessorCount  number of SMs, used to size the grid.
//
// Returns without launching when there is nothing to convert (m <= 0 or
// fewer than ELTS_PER_THREAD columns); such sizes would otherwise produce a
// zero-sized block and a division by zero below.
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
                           int multiProcessorCount, cudaStream_t stream) {
  // Each thread converts ELTS_PER_THREAD values.
  int const threadsPerRow = int(n / ELTS_PER_THREAD);
  if (m <= 0 || threadsPerRow <= 0) {
    return;
  }

  // Grid, Block size.
  dim3 block(std::min(threadsPerRow, 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // The kernel works on 32-bit words for both outputs.
  auto* out32 = reinterpret_cast<uint32_t*>(output);
  auto* sf32 = reinterpret_cast<uint32_t*>(SFOuput);

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true>
        <<<grid, block, 0, stream>>>(m, n, input, SFScale, out32, sf32);
  } else {
    cvt_fp16_to_fp4<T, false>
        <<<grid, block, 0, stream>>>(m, n, input, SFScale, out32, sf32);
  }
}

// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input,
                                    float const* SFScale, int64_t* output,
                                    int32_t* SFOuput, bool useUE8M0,
                                    int multiProcessorCount,
                                    cudaStream_t stream);

template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
                                    float const* SFScale, int64_t* output,
                                    int32_t* SFOuput, bool useUE8M0,
                                    int multiProcessorCount,
                                    cudaStream_t stream);
// Quantizes the 2-D fp16/bf16 tensor `input` to packed FP4.
//
//   output     receives the packed e2m1 values.
//   input      contiguous m x n tensor of half or bfloat16; n must be a
//              multiple of 16.
//   output_sf  receives the per-vector scale factors.
//   input_sf   global scale; its storage is read as float.
//
// The kernel is launched on the current CUDA stream of the input's device so
// that it is ordered with the surrounding PyTorch work on that stream.
// Throws std::runtime_error for any other input dtype.
void scaled_fp4_quant_sm100a(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf) {
  // The kernel indexes the input as dense rows of n elements.
  TORCH_CHECK(input.dim() == 2, "The input must be a 2-D tensor.");
  TORCH_CHECK(input.is_contiguous(), "The input must be contiguous.");

  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");

  // Select the input's device before querying it or picking a stream.
  at::cuda::CUDAGuard device_guard{static_cast<char>(input.get_device())};
  int multiProcessorCount = get_device_attribute(
      cudaDevAttrMultiProcessorCount, input.get_device());
  cudaStream_t const stream =
      at::cuda::getCurrentCUDAStream(input.get_device());

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
                            useUE8M0, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr =
          reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
                            useUE8M0, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type()
                << " for the input datatype which is invalid" << '\n';
      throw std::runtime_error(
          "Unsupported input data type for quantize_to_fp4.");
    }
  }
}
csrc/torch_bindings.cpp
View file @
4fc5c23b
...
...
@@ -423,6 +423,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_per_token_scaled_fp8_quant
);
// Compute NVFP4 block quantized tensor.
ops
.
def
(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
...
...
tests/kernels/test_nvfp4_quant.py
0 → 100644
View file @
4fc5c23b
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
# Skip the whole module unless the platform reports device capability 100.
if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

# Parameter sets for the tests below.
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
# Shapes whose row count is not a multiple of 128 and/or whose scale count
# per row is not a multiple of 4, i.e. where the scale buffer is padded.
PAD_SHAPES = [
    (90, 64),
    (150, 64),
    (128, 48),
    (128, 80),
    (150, 80),
    (90, 48),
    (90, 128),
    (150, 128),
    (150, 48),
    (90, 80),
]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']

# Largest representable magnitudes of the two formats involved.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# E2M1 to float, indexed by the 4-bit code. Codes 8-15 repeat codes 0-7 with
# a negative sign.
# 0111 -> 6
# 0110 -> 4
# 0101 -> 3
# 0100 -> 2
# 0011 -> 1.5
# 0010 -> 1
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
    0., 0.5, 1., 1.5, 2., 3., 4., 6.,
    0., -0.5, -1., -1.5, -2., -3., -4., -6.,
]
# Number of elements that share one scale factor.
BLOCK_SIZE = 16
def cast_from_fp4(x, m, n):
    """Unpack a uint8 tensor of packed fp4 codes into an (m, n) float32 tensor.

    Every byte holds two 4-bit codes. The low nibble is emitted first and the
    high nibble second, and each code is decoded through E2M1_TO_FLOAT32.

    Args:
        x: tensor of packed bytes holding m * n / 2 elements.
        m: number of rows of the unpacked result.
        n: number of columns of the unpacked result.

    Returns:
        A float32 tensor of shape (m, n) on the same device as `x`.
    """
    low_nibble = x & 0xF
    high_nibble = (x >> 4) & 0xF
    codes = torch.stack((low_nibble, high_nibble), dim=-1)
    # Decode every code with one indexed lookup instead of a Python loop over
    # individual tensor elements.
    lookup = torch.tensor(E2M1_TO_FLOAT32,
                          dtype=torch.float32,
                          device=x.device)
    return lookup[codes.to(torch.long)].reshape(m, n)
def cast_to_fp4(x):
    """Round every element of `x` to the nearest value representable in e2m1.

    Magnitudes are snapped to one of 0, 0.5, 1, 1.5, 2, 3, 4, 6 and the
    original sign is re-applied afterwards. The interval bounds below decide
    which neighbour an exact midpoint goes to. The input is not modified.
    """
    magnitude = torch.abs(x)
    snapped = magnitude.clone()
    # (mask over the original magnitudes, representable value) pairs. No
    # assigned value falls into a later interval, so evaluating every mask up
    # front is equivalent to applying the intervals one after another.
    intervals = (
        ((magnitude >= 0.0) & (magnitude <= 0.25), 0.0),
        ((magnitude > 0.25) & (magnitude < 0.75), 0.5),
        ((magnitude >= 0.75) & (magnitude <= 1.25), 1.0),
        ((magnitude > 1.25) & (magnitude < 1.75), 1.5),
        ((magnitude >= 1.75) & (magnitude <= 2.5), 2.0),
        ((magnitude > 2.5) & (magnitude < 3.5), 3.0),
        ((magnitude >= 3.5) & (magnitude <= 5.0), 4.0),
        (magnitude > 5.0, 6.0),
    )
    for mask, level in intervals:
        snapped[mask] = level
    return snapped * torch.sign(x)
def get_reciprocal(x):
    """Return 1 / x, with the reciprocal of zero defined as zero.

    Accepts a torch.Tensor (handled element-wise) or a Python float/int.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(x, torch.Tensor):
        zero = torch.tensor(0.0, dtype=x.dtype)
        return torch.where(x == 0, zero, 1.0 / x)
    if not isinstance(x, (float, int)):
        raise TypeError("Input must be a float, int, or a torch.Tensor.")
    return 1.0 / x if x != 0 else 0.0
def ref_nvfp4_quant(x, global_scale):
    """Reference implementation of NVFP4 block quantization.

    Args:
        x: 2-D tensor of shape (m, n); n is split into blocks of BLOCK_SIZE.
        global_scale: float32 scale applied to every block scale.

    Returns:
        A tuple of the quantized values, shape (m, n), holding e2m1
        representable numbers, and the per-block scales after a round trip
        through float8_e4m3fn, shape (m, n // BLOCK_SIZE).
    """
    assert global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape

    blocks = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
    block_amax = torch.abs(blocks).max(dim=-1,
                                       keepdim=True).values.to(torch.float32)

    # Block scale: global scale * (block max / e2m1 max), stored as fp8.
    raw_scale = global_scale * (block_amax * get_reciprocal(FLOAT4_E2M1_MAX))
    fp8_scale = raw_scale.to(torch.float8_e4m3fn).to(torch.float32)

    # Multiplier that maps the inputs into the e2m1 range.
    multiplier = get_reciprocal(fp8_scale * get_reciprocal(global_scale))
    scaled = blocks.to(torch.float32) * multiplier
    clipped = torch.clamp(scaled, -6.0, 6.0).reshape(m, n)

    return cast_to_fp4(clipped), fp8_scale.squeeze(-1)
def recover_swizzled_scales(scale, m, n):
    """Convert swizzled scale factors back to a linear (row, block) layout.

    Args:
        scale: swizzled scales, padded to a multiple of 128 rows and a
            multiple of 4 scales per row.
        m: number of rows of the quantized matrix.
        n: number of columns of the quantized matrix.

    Returns:
        A float32 tensor of shape (m, n // BLOCK_SIZE) with the padding
        removed.
    """

    def _round_up(value, multiple):
        return (value + multiple - 1) // multiple * multiple

    padded_m = _round_up(m, 128)
    scale_n = n // BLOCK_SIZE
    padded_n = _round_up(scale_n, 4)

    # Recover the swizzled scaling factor to linear layout
    tiled = scale.reshape(1, padded_m // 128, padded_n // 4, 32, 4, 4)
    linear = tiled.permute(0, 1, 4, 3, 2, 5)
    unpadded = linear.reshape(padded_m, padded_n).to(torch.float32)
    return unpadded[:m, :scale_n]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_quantize_to_fp4(
    dtype: torch.dtype,
    shape: tuple[int, int],
    seed: int,
    device: str,
) -> None:
    """Compare ops.scaled_fp4_quant with the reference on random input."""
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    rows, cols = shape
    x = torch.randn((rows, cols), dtype=dtype)

    # Global scale derived from the largest magnitude in the tensor.
    amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

    expected_values, expected_scales = ref_nvfp4_quant(x, global_scale)

    packed, swizzled_scales = ops.scaled_fp4_quant(x, global_scale)
    actual_scales = recover_swizzled_scales(swizzled_scales, rows, cols)
    actual_values = cast_from_fp4(packed, rows, cols)

    torch.testing.assert_close(actual_values, expected_values)
    torch.testing.assert_close(actual_scales, expected_scales)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
    """Compare ops.scaled_fp4_quant with the reference on the PAD_SHAPES."""
    dtype = torch.float16
    current_platform.seed_everything(42)
    torch.set_default_device('cuda:0')

    rows, cols = pad_shape
    x = torch.randn((rows, cols), dtype=dtype)

    # Global scale derived from the largest magnitude in the tensor.
    amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

    expected_values, expected_scales = ref_nvfp4_quant(x, global_scale)

    packed, swizzled_scales = ops.scaled_fp4_quant(x, global_scale)
    actual_scales = recover_swizzled_scales(swizzled_scales, rows, cols)
    actual_values = cast_from_fp4(packed, rows, cols)

    torch.testing.assert_close(actual_values, expected_values)
    torch.testing.assert_close(actual_scales, expected_scales)
tests/test_scalartype.py
View file @
4fc5c23b
...
...
@@ -11,6 +11,7 @@ from vllm.scalar_type import scalar_types
(
0
,
15
,
scalar_types
.
uint4
),
(
-
8
,
7
,
scalar_types
.
uint4b8
),
(
-
128
,
127
,
scalar_types
.
uint8b128
),
(
-
6.
,
6.
,
scalar_types
.
float4_e2m1fn
),
(
-
28.
,
28.
,
scalar_types
.
float6_e3m2f
),
(
torch
.
int8
,
scalar_types
.
int8
),
(
torch
.
uint8
,
scalar_types
.
uint8
),
...
...
vllm/_custom_ops.py
View file @
4fc5c23b
...
...
@@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
permute_cols
(
a
,
perm
)
# fp4
def scaled_fp4_quant(
        input: torch.Tensor,
        input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    The last dimension of `input` is quantized; all leading dimensions are
    flattened into one. Every 16 consecutive elements share one dynamically
    computed scaling factor. That factor is quantized using
    `input_global_scale` and stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4. Must be fp16 or bf16
            and its last dimension must be a multiple of 16.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The FP4 output with every two
            values packed into one uint8, and the float8_e4m3fn scaling
            factors in the swizzled layout.
    """
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')

    # Collapse to 2-D: a 1-D input becomes a single row.
    if input.ndim == 1:
        input = input.reshape(1, input.shape[-1])
    else:
        input = input.reshape(-1, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    def _round_up(value, multiple):
        return (value + multiple - 1) // multiple * multiple

    # We use the rounded values to store the swizzled values. Due to the
    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
    rounded_m = _round_up(m, 128)
    scale_n = n // block_size
    rounded_n = _round_up(scale_n, 4)
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=device,
                               dtype=torch.int32)

    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
                                  input_global_scale)
    # Reinterpret the int32 storage as the individual fp8 scale factors.
    return output, output_scale.view(torch.float8_e4m3fn)
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/scalar_type.py
View file @
4fc5c23b
...
...
@@ -321,6 +321,9 @@ class scalar_types:
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f
=
ScalarType
.
float_
(
3
,
2
,
True
,
NanRepr
.
NONE
)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1fn
=
ScalarType
.
float_
(
2
,
1
,
True
,
NanRepr
.
NONE
)
# "gptq" types
uint2b2
=
ScalarType
.
uint
(
2
,
2
)
uint3b4
=
ScalarType
.
uint
(
3
,
4
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment