Commit 01caaf29 authored by helloyongyang's avatar helloyongyang

Add lightx2v_kernel for nvfp4

parent ea618db2
@@ -22,6 +22,5 @@
*.pid
*.ipynb*
*.mp4
# just4dev
devscripts/
build/
dist/
cmake_minimum_required(VERSION 3.22 FATAL_ERROR)
project(lightx2v-kernel LANGUAGES CXX CUDA)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Python
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
# CXX
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# CUDA
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
# Torch
find_package(Torch REQUIRED)
# Clear the -gencode flags that Torch injects into CMAKE_CUDA_FLAGS
clear_cuda_arches(CMAKE_FLAG)
# cutlass
if(CUTLASS_PATH)
set(repo-cutlass_SOURCE_DIR ${CUTLASS_PATH})
message(STATUS "Using local CUTLASS from: ${CUTLASS_PATH}")
else()
message(STATUS "Start to git clone CUTLASS from GitHub...")
include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG b244379d9b15574e07b73b814b88bd2233f0b3ce
GIT_SHALLOW OFF
)
FetchContent_MakeAvailable(repo-cutlass)
message(STATUS "Using CUTLASS from ${repo-cutlass_SOURCE_DIR}")
endif()
# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
message(STATUS "Building with CCACHE enabled")
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache")
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif()
include_directories(
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
)
set(LIGHTX2V_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=lightx2v-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--threads=32"
# Host compiler options
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
list(APPEND LIGHTX2V_KERNEL_CUDA_FLAGS
# "-gencode=arch=compute_90,code=sm_90"
# "-gencode=arch=compute_90a,code=sm_90a"
# "-gencode=arch=compute_100,code=sm_100"
# "-gencode=arch=compute_100a,code=sm_100a"
# "-gencode=arch=compute_120,code=sm_120"
"-gencode=arch=compute_120a,code=sm_120a"
)
set(SOURCES
"csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
"csrc/gemm/nvfp4_quant_kernels_sm120.cu"
"csrc/common_extension.cc"
)
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
message(STATUS "LIGHTX2V_KERNEL_CUDA_FLAGS: ${LIGHTX2V_KERNEL_CUDA_FLAGS}")
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${LIGHTX2V_KERNEL_CUDA_FLAGS}>)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
install(TARGETS common_ops LIBRARY DESTINATION lightx2v_kernel)
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# lightx2v_kernel
### Preparation
```
# Install torch (version 2.7 or later) first, then the build tools:
pip install scikit_build_core uv
```
### Build whl
```
MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) \
uv build --wheel \
-Cbuild-dir=build . \
--verbose \
--color=always \
--no-build-isolation
```
During the build, the CUTLASS source code is downloaded automatically. If you already have a local copy of CUTLASS, you can point the build at it:
```
MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) \
uv build --wheel \
-Cbuild-dir=build . \
-Ccmake.define.CUTLASS_PATH=/path/to/cutlass \
--verbose \
--color=always \
--no-build-isolation
```
### Install whl
```
pip install dist/*.whl --force-reinstall --no-deps
```
### Test
##### Cosine similarity and speed test, mm without bias
```
python test/test_bench2.py
```
##### Cosine similarity and speed test, mm with bias
```
python test/test_bench3_bias.py
```
##### Bandwidth utilization test for quant
```
python test/test_quant_mem_utils.py
```
##### TFLOPS test for mm
```
python test/test_mm_tflops.py
```
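### Usage example
After installation, both kernels can be driven from Python. Below is a minimal, hedged usage sketch; the shapes and the global-scale/alpha recipe follow the nvfp4 notes included in this commit, so adjust them if your quantization recipe differs.
```
# Minimal usage sketch (assumes the wheel is installed and an SM120-class GPU).
import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant

m, n, k = 256, 512, 1024
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Per-tensor global scales: 6.0 * 448.0 / max(|X|), as in the nvfp4 notes.
a_gs = (6.0 * 448.0 / a.abs().max().float()).reshape(1)
b_gs = (6.0 * 448.0 / b.abs().max().float()).reshape(1)

a_fp4, a_sf = scaled_fp4_quant(a, a_gs)  # packed uint8 data + swizzled fp8 scales
b_fp4, b_sf = scaled_fp4_quant(b, b_gs)

# alpha = 1 / (Aglobal_scale * Bglobal_scale), applied in the GEMM epilogue.
alpha = (1.0 / (a_gs * b_gs)).reshape(1, 1)
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha)  # bf16, shape (m, n)
```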
# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
endmacro()
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>
#include "lightx2v_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
m.def(
"cutlass_scaled_fp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_fp4_mm_sm120", torch::kCUDA, &cutlass_scaled_fp4_mm_sm120);
m.def(
"scaled_fp4_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);
}
REGISTER_EXTENSION(common_ops)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
// The PTX instructions used here require sm100a.
// #if CUDA_VERSION >= 12080
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]),
"f"(array[1]),
"f"(array[2]),
"f"(array[3]),
"f"(array[4]),
"f"(array[5]),
"f"(array[6]),
"f"(array[7]));
return val;
// #else
// return 0;
// #endif
// #endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
// The PTX instructions used here require sm100a.
// #if CUDA_VERSION >= 12080
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
return val;
// #else
// return 0;
// #endif
// #endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
// #endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * 0.16666666666666666f);
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
fp8SFVal = tmp.__x;
SFValue = static_cast<float>(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
// float outputScale =
// SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
float outputScale =
SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
// #else
// return 0;
// #endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp4(
// #else
// cvt_fp16_to_fp4(
// #endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
// #endif
}
template <typename T>
void invokeFP4Quantization(
int m,
int n,
T const* input,
float const* SFScale,
int64_t* output,
int32_t* SFOuput,
bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 1536 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
if (useUE8M0) {
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
} else {
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}
}
// Instantiate the function.
template void invokeFP4Quantization(
int m,
int n,
half const* input,
float const* SFScale,
int64_t* output,
int32_t* SFOuput,
bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(
int m,
int n,
__nv_bfloat16 const* input,
float const* SFScale,
int64_t* output,
int32_t* SFOuput,
bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
inline int getMultiProcessorCount() {
static int multi_processor_count = []() {
int device_id = 0;
int count = 0;
// Get the current CUDA device ID
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
// Get the number of multiprocessors for the current device
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
return count; // Initialize the static variable
}();
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
int multiProcessorCount = getMultiProcessorCount();
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
// We don't support e8m0 scales at this moment.
bool useUE8M0 = false;
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
}
}
}
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
using namespace cute;
struct Fp4GemmSm120 {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
};
// Populates a Gemm::Arguments structure from the given commandline options
typename Fp4GemmSm120::Gemm::Arguments args_from_options(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t M,
int64_t N,
int64_t K) {
using Sm1xxBlkScaledConfig = typename Fp4GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
int k = static_cast<int>(K);
auto stride_A = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideA{}, {m, k, 1});
auto stride_B = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideD{}, {m, n, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideC{}, {});
typename Fp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Fp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Fp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue4m3_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue4m3_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Fp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
} else {
typename Fp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Fp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Fp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue4m3_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue4m3_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Fp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Fp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
}
}
void runGemm(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Fp4GemmSm120::Gemm gemm;
auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(
A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (",
A.sizes()[0],
"x",
A.sizes()[1],
" and ",
B.sizes()[0],
"x",
B.sizes()[1],
")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
constexpr int alignment = 32;
TORCH_CHECK(
k % alignment == 0,
"Expected k to be divisible by ",
alignment,
", but got a shape: (",
A.sizes()[0],
"x",
A.sizes()[1],
"), k: ",
k,
".");
TORCH_CHECK(
n % alignment == 0,
"Expected n to be divisible by ",
alignment,
", but got b shape: (",
B.sizes()[0],
"x",
B.sizes()[1],
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
int rounded_n = round_up(n, 128);
// Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(
A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
" and ",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
TORCH_CHECK(
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (",
rounded_m,
"x",
rounded_k,
"), but got a shape (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
")");
TORCH_CHECK(
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (",
rounded_n,
"x",
rounded_k,
"), but got a shape (",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
runGemm(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
# nvfp4 Quantization Basics
### Data Format
A floating-point value is computed as:
`ans = (-1)^s * 2^(p-b) * (1 + d1/2 + d2/4 + d3/8 + ...)`
where `b = 2^(e-1) - 1` is the exponent bias (e is the number of exponent bits), p is the value of the exponent field, and d1, d2, d3 are the mantissa bits.
For fp4 the format is E2M1 (2 exponent bits, 1 mantissa bit), so the formula simplifies to:
`b = 2^(e-1) - 1 = 2^(2-1) - 1 = 1`
`ans = (-1)^s * 2^(p-1) * (1 + d1/2)`
Example: 0101
`s=0, p=(10)_2=2, d1=1`
`ans = (-1)^0 * 2^(2-1) * (1 + 1/2) = 3`
In a standard floating-point format, some encodings are reserved for inf and NaN, which would cap the representable magnitude at ±3. nvfp4 drops inf and NaN, extending the maximum magnitude to ±6.
Specifically, 0000 represents +0, 1000 represents -0, 0001 represents 0.5, and 1001 represents -0.5.
In summary:
| E2M1 | 0000 | 0001 | 0010 | 0011 | 0100 | 0101 | 0110 | 0111 | 1000 | 1001 | 1010 | 1011 | 1100 | 1101 | 1110 | 1111 |
|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| ans | +0 | 0.5 | 1.0 | 1.5 | 2.0 | 3.0 | 4.0 | 6.0 | -0 | -0.5 | -1.0 | -1.5 | -2.0 | -3.0 | -4.0 | -6.0 |
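The table can be reproduced directly from the formula. The helper below is a hypothetical illustration only (it is not part of lightx2v_kernel); `decode_e2m1` is an assumed name.
```
# Hypothetical helper, illustration only (not part of lightx2v_kernel):
# decode a 4-bit E2M1 code with ans = (-1)^s * 2^(p-1) * (1 + d1/2),
# treating p == 0 as the special row (+0 / -0 and +-0.5).
def decode_e2m1(code: int) -> float:
    s = (code >> 3) & 0x1   # sign bit
    p = (code >> 1) & 0x3   # two exponent bits
    d1 = code & 0x1         # one mantissa bit
    mag = 0.5 * d1 if p == 0 else 2.0 ** (p - 1) * (1 + d1 / 2)
    return -mag if s else mag

assert [decode_e2m1(c) for c in range(8)] == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```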
### Quantization Process
**Both weight and activation use per-group quantization, with a group size of 16, and quantization scales are stored in fp8(e4m3) format**
Since the quantization scale itself is stored in fp8, the scale must also be rescaled, so the fp4 quantization flow differs somewhat from the familiar w8a8-int8 flow.
The quantization process is as follows:
Given a set of numbers, denoted as `X`
#### Calculate scale
`scale1 = max(abs(Xg)) / 6.0`
where `Xg` is one group of 16 values and 6.0 is the maximum magnitude representable in nvfp4.
#### Quantize scale
`global_scale = 6.0 * 448.0 / max(abs(X))`
`scale2 = global_scale * scale1`
That is `scale2 = 6.0 * 448.0 / max(abs(X)) * max(abs(Xg)) / 6.0`
That is `scale2 = max(abs(Xg)) / max(abs(X)) * 448.0`
At this point scale2 has been rescaled into the fp8 (e4m3) range; scale2 is then quantized to fp8:
`scale2_fp8 = quant_fp8(scale2)`
`scale2_fp8` serves as the final quantization scale parameter required for matrix multiplication
#### Quantize X
`scale2_fp32 = cvt2fp32(scale2_fp8)`
`Xquant = quant_fp4(X * global_scale / scale2_fp32)`
Then `Xquant ≈ quant_fp4(X / scale1)`
#### fp4 Matrix Multiplication
`ans = Aquant * Bquant * Ascale2 * Bscale2 / Aglobal_scale / Bglobal_scale`
That is `ans ≈ Aquant * Bquant * Aglobal_scale * Ascale1 * Bglobal_scale * Bscale1 / Aglobal_scale / Bglobal_scale`
That is `ans ≈ Aquant * Bquant * Ascale1 * Bscale1`
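The same recipe in a short PyTorch sketch, as a reference-path illustration of the formulas above (mirroring `ref_nvfp4_quant` in `test/`); the actual kernel does this in CUDA and additionally swizzles the scale layout, and the final fp4 rounding step is omitted here.
```
# Reference-only sketch of the scale math above; assumes x is a 2-D fp16/bf16
# tensor whose last dimension is a multiple of 16.
import torch

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0
BLOCK_SIZE = 16

def nvfp4_scales_sketch(x: torch.Tensor):
    m, n = x.shape
    xg = x.float().reshape(m, n // BLOCK_SIZE, BLOCK_SIZE)
    global_scale = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / x.abs().max().float()  # 6 * 448 / max|X|
    scale1 = xg.abs().amax(dim=-1, keepdim=True) / FLOAT4_E2M1_MAX           # max|Xg| / 6
    scale2 = global_scale * scale1                                           # rescaled into e4m3 range
    scale2_fp8 = scale2.to(torch.float8_e4m3fn)                              # stored per-group scale
    # Quantize X against the dequantized fp8 scale so that
    # Xquant * scale2_fp8 / global_scale recovers X up to fp4 rounding.
    inv_scale2 = torch.where(scale2_fp8.float() == 0,
                             torch.zeros_like(scale2),
                             1.0 / scale2_fp8.float())
    x_scaled = (xg * global_scale * inv_scale2).clamp(-FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX)
    return x_scaled.reshape(m, n), scale2_fp8.squeeze(-1)
```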
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <tuple>
#include <vector>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
/*
* From csrc/gemm
*/
void cutlass_scaled_fp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
void scaled_fp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <ATen/Tensor.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <sstream>
#ifndef USE_ROCM
// Adapt from FlashInfer
#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = nv_half; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = nv_bfloat16; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH(var_name, cond, ...) \
[&]() -> bool { \
switch (cond) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
[&]() -> bool { \
switch (pack_u16(cond1, cond2)) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \
<< int(cond2) << ")"; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
case pack_u16(case_expr1, case_expr2): { \
constexpr auto case_var1 = case_expr1; \
constexpr auto case_var2 = case_expr2; \
return __VA_ARGS__(); \
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
TORCH_CHECK( \
num_qo_heads % num_kv_heads == 0, \
"num_qo_heads(", \
num_qo_heads, \
") must be divisible by num_kv_heads(", \
num_kv_heads, \
")")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at the last dimension")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CUDA(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
inline bool is_float8_tensor(const at::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
#endif
struct cuda_error : public std::runtime_error {
/**
* @brief Constructs a `cuda_error` object with the given `message`.
*
* @param message The error char array used to construct `cuda_error`
*/
cuda_error(const char* message) : std::runtime_error(message) {}
/**
* @brief Constructs a `cuda_error` object with the given `message` string.
*
* @param message The `std::string` used to construct `cuda_error`
*/
cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
std::stringstream _message; \
auto s = cudaGetErrorString(e); \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
throw cuda_error(_message.str()); \
} \
} while (0)
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_IS_CUDA(x); \
CHECK_IS_CONTIGUOUS(x)
inline int getSMVersion() {
int device{-1};
CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
} \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#endif
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define WARP_SIZE 32
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
using FP8_TYPE = c10::Float8_e4m3fnuz;
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
#ifndef USE_ROCM
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
__device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
return max_value;
}
__device__ __forceinline__ float blockReduceMax(float max_value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
return max_value;
}
#endif
// Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
int64_t rows = tensor.size(0);
int64_t cols = tensor.size(1);
int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size
if (pad_rows == 0) {
return tensor; // Already aligned
}
torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows
// Ensure column-major layout
if (is_column_major) {
return tensor_padded.t().contiguous().t();
}
return tensor_padded;
}
[build-system]
requires = [
"scikit-build-core>=0.10",
"torch>=2.7.0",
"wheel",
]
build-backend = "scikit_build_core.build"
[project]
name = "lightx2v-kernel"
version = "0.0.1"
description = "Kernel Library for lightx2v"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA"
]
dependencies = []
[project.urls]
"Homepage" = ""
"Bug Tracker" = ""
[tool.wheel]
exclude = [
"dist*",
"tests*",
]
[tool.scikit-build]
cmake.build-type = "Release"
minimum-version = "build-system.requires"
wheel.py-api = "cp39"
wheel.license-files = []
wheel.packages = ["python/lightx2v_kernel"]
import ctypes
import os
import platform
import torch
SYSTEM_ARCH = platform.machine()
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
from lightx2v_kernel import common_ops
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant
from lightx2v_kernel.version import __version__
build_tree_kernel = None
import torch
def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
"""
mat_a: (m, k) cutlass::float_e2m1_t
mat_b: (n, k) cutlass::float_e2m1_t
scales_a: (round_up(m, 128), round_up(k // 16, 4)) cutlass::float_ue4m3_t, swizzled
scales_b: (round_up(n, 128), round_up(k // 16, 4)) cutlass::float_ue4m3_t, swizzled
alpha: (1, 1) float
bias: (m, n) cutlass::bfloat16_t
"""
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(
out, mat_a, mat_b, scales_a, scales_b, alpha, bias
)
return out
def scaled_fp4_quant(
input: torch.Tensor, input_global_scale: torch.Tensor
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
This function quantizes the last dimension of the given tensor `input`. For
every 16 consecutive elements, a single dynamically computed scaling factor
is shared. This scaling factor is quantized using the `input_global_scale`
and is stored in a swizzled layout (see
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
Args:
input: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
two values are packed into a uint8 and float8_e4m3 scaling factors
in a sizzled layout.
"""
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device
    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
    # The swizzled scaling factors in float8_e4m3fn are packed into an int32
    # for every 4 values; rows are padded to a multiple of 128 and scale
    # columns to a multiple of 4.
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // block_size
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    output_scale = torch.empty(
        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )
torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(
output, input, output_scale, input_global_scale
)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
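# Minimal end-to-end sketch (illustrative only, not part of the API): quantize
# both operands and run the FP4 GEMM. Assumes an SM120 GPU with the compiled
# extension importable as `lightx2v_kernel`; shapes follow the tests in this commit.
if __name__ == "__main__":
    import lightx2v_kernel  # noqa: F401  # loads the compiled common_ops extension
    m, k, n = 128, 256, 512
    a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    # 448 (fp8 e4m3 max) * 6 (fp4 e2m1 max) / absmax, as in the tests.
    a_gs = (448.0 * 6.0 / a.abs().max()).to(torch.float32)
    b_gs = (448.0 * 6.0 / b.abs().max()).to(torch.float32)
    a_fp4, a_sf = scaled_fp4_quant(a, a_gs)  # (m, k // 2) uint8, swizzled fp8 scales
    b_fp4, b_sf = scaled_fp4_quant(b, b_gs)
    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, 1.0 / (a_gs * b_gs))
    print(out.shape, out.dtype)  # (128, 512) torch.bfloat16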
import functools
from typing import Dict, Tuple
import torch
def get_cuda_stream() -> int:
return torch.cuda.current_stream().cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
@functools.lru_cache(maxsize=1)
def is_hopper_arch() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major == 9
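# Sketch of an analogous check for the SM120 GPUs (compute capability 12.x)
# that the nvfp4 kernels in this commit are built for; `is_sm120_arch` is an
# illustrative name, not an existing helper in this repo.
@functools.lru_cache(maxsize=1)
def is_sm120_arch() -> bool:
    device = torch.cuda.current_device()
    major, _minor = torch.cuda.get_device_capability(device)
    return major == 12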
import torch
BLOCK_SIZE = 16
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign
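# A few sample values, rounded onto the e2m1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}
# with round-to-nearest-even at the midpoints (illustrative check only):
#   cast_to_fp4(torch.tensor([0.3, 1.3, 2.6, 5.5, -0.3]))
#   -> tensor([0.5, 1.5, 3.0, 6.0, -0.5])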
def get_reciprocal(x):
if isinstance(x, torch.Tensor):
        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype, device=x.device), 1.0 / x)
elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x
else:
raise TypeError("Input must be a float, int, or a torch.Tensor.")
def ref_nvfp4_quant(x, global_scale):
assert global_scale.dtype == torch.float32
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
# output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
output_scale = global_scale * get_reciprocal(scale)
scaled_x = x.to(torch.float32) * output_scale
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
return cast_to_fp4(clipped_x), scale.squeeze(-1)
if __name__ == "__main__":
x = torch.randn(1, 16, dtype=torch.bfloat16).cuda()
print(f"x: {x}, {x.shape}")
global_scale = (6.0 * 448.0 / torch.max(torch.abs(x))).to(torch.float32).cuda()
quant_x, scale = ref_nvfp4_quant(x, global_scale)
print(f"quant_x: {quant_x}, {quant_x.shape}")
print(f"scale: {scale}, {scale.shape}")
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloatArray = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
]
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
sf_m, sf_k = a_sf_swizzled.shape
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def e2m1_to_fp32(int4_value):
signBit = int4_value & 0x8
int4_absValue = int4_value & 0x7
float_result = kE2M1ToFloatArray[int4_absValue]
if signBit:
float_result = -float_result
return float_result
def break_fp4_bytes(a, dtype):
    # Slow reference unpacking of packed e2m1 nibbles; fine for small test tensors.
    assert a.dtype == torch.uint8
m, n = a.shape
a = a.flatten()
# Get upper 4 bits
highHalfByte = (a & 0xF0) >> 4
# Get lower 4 bits
lowHalfByte = a & 0x0F
fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
return out
def dequantize_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@torch.inference_mode()
def test_nvfp4_gemm(
dtype: torch.dtype,
    shape: tuple[int, int, int],
) -> None:
m, n, packed_k = shape
k = packed_k * 2
block_size = 16
a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
bias = torch.randn((1, n), dtype=dtype, device="cuda")
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten().abs(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten().abs(), dim=-1)
    ).to(torch.float32)
print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}")
print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
"cuda",
)
expected_out = expected_out + bias
print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
out = cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias
)
print(f"out : {out}, {out.shape}, {out.dtype}")
print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
if __name__ == "__main__":
test_nvfp4_gemm(torch.bfloat16, (128, 512, 128))
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time
class MMWeightFp4:
def __init__(self, weight, bias):
self.load_fp4_weight(weight, bias)
self.act_quant_func = self.act_quant_fp4
# calibrate x_max
self.calibrate_x_absmax()
@torch.no_grad()
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@torch.no_grad()
def load_fp4_weight(self, weight, bias):
        # 2688 = 448 (fp8 e4m3 max) * 6 (fp4 e2m1 max)
        self.weight_global_scale = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32)
self.weight, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
self.bias = bias
def calibrate_x_absmax(self):
        self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device)  # placeholder; should be calibrated from real activations
self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32)
self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)
@torch.no_grad()
def act_quant_fp4(self, x):
return scaled_fp4_quant(x, self.input_global_scale)
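# Usage sketch: MMWeightFp4 stands in for a bf16 linear layer, with
# weight: (out_features, in_features) and x: (tokens, in_features); the names
# below are illustrative only.
#   mm = MMWeightFp4(weight, bias=None)
#   y = mm.apply(x)   # (tokens, out_features) bf16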
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightFp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
# linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightFp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i+1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)