Unverified commit eb34783c, authored by Przemyslaw Tredak, committed by GitHub

Overhaul the compilation for the arch-specific features (#2279)



* Added sm_120f to the build
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the arch-specific handling
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Support for CUDA < 12.9
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Moved through the rest of the files
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Common cases
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Remove pure 100 from the list
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* CMake changes (not yet working)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Do not pass the arch-specific architectures from build_tools
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Moved some of the files to arch-specific compilation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix, and change the compilation order to hopefully lower the compilation time
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for the files overwriting custom compile properties
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Actually make this whole thing work
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add space to the error message
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Apply suggestions from code review
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Fixes from review
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changing the naming to be more intuitive
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add missing cassert include for device-side asserts
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
parent 66acb8e9
@@ -257,11 +257,9 @@ def cuda_archs() -> str:
     if archs is None:
         version = cuda_version()
         if version >= (13, 0):
-            archs = "75;80;89;90;100;100a;103a;120"
-        elif version >= (12, 9):
-            archs = "70;80;89;90;100;100a;103a;120"
+            archs = "75;80;89;90;100;120"
         elif version >= (12, 8):
-            archs = "70;80;89;90;100;100a;120"
+            archs = "70;80;89;90;100;120"
         else:
             archs = "70;80;89;90"
     return archs
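Some context on the arch strings being trimmed here: a plain entry such as 100 or 120 is a generic target whose binary also runs on later GPUs, while the a-suffixed entries (100a, 103a) are arch-specific targets that unlock instructions tied to exactly that SM. The arch-specific entries no longer need to be listed here because CMake now derives them from the generic ones (see the CMakeLists.txt changes below). A minimal probe, relying only on nvcc's documented behavior of defining __CUDA_ARCH_FEAT_SM100_ALL when compiling for compute_100a (sketch, not TE source):

#include <cstdio>

// Shows which macros a translation unit sees per codegen target.
__global__ void report_target() {
#if defined(__CUDA_ARCH_FEAT_SM100_ALL)
  printf("sm_100a build: SM100-specific instructions may be used\n");
#elif defined(__CUDA_ARCH__)
  printf("generic build: __CUDA_ARCH__ = %d\n", __CUDA_ARCH__);
#endif
}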
@@ -5,15 +5,6 @@
 cmake_minimum_required(VERSION 3.21)
 
 # Language options
-if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
-    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
-  elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
-    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
-  else ()
-    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
-  endif()
-endif()
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CUDA_STANDARD 17)
 set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -30,8 +21,62 @@ project(transformer_engine LANGUAGES CUDA CXX)
 
 # CUDA Toolkit
 find_package(CUDAToolkit REQUIRED)
-if (CUDAToolkit_VERSION VERSION_LESS 12.0)
-  message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
+if (CUDAToolkit_VERSION VERSION_LESS 12.1)
+  message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
 endif()
 
+# Process GPU architectures
+if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+  elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
+  else ()
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
+  endif()
+endif()
+
+# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
+set(NVTE_GENERIC_ARCHS)
+set(NVTE_SPECIFIC_ARCHS)
+
+# Check for architecture 100
+list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
+if(NOT arch_100_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
+  list(APPEND NVTE_GENERIC_ARCHS "100")
+  list(APPEND NVTE_SPECIFIC_ARCHS "100a")
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+    list(APPEND NVTE_SPECIFIC_ARCHS "103a")
+  endif()
+endif()
+
+# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
+list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
+if(NOT arch_101_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
+  list(APPEND NVTE_GENERIC_ARCHS "101")
+  list(APPEND NVTE_SPECIFIC_ARCHS "101a")
+endif()
+
+# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
+list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
+if(NOT arch_110_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
+  list(APPEND NVTE_GENERIC_ARCHS "110")
+  list(APPEND NVTE_SPECIFIC_ARCHS "110f")
+endif()
+
+# Check for architecture 120
+list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
+if(NOT arch_120_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
+  list(APPEND NVTE_GENERIC_ARCHS "120")
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+    list(APPEND NVTE_SPECIFIC_ARCHS "120f")
+  else()
+    list(APPEND NVTE_SPECIFIC_ARCHS "120a")
+  endif()
+endif()
+
 # cuDNN frontend API
@@ -78,9 +123,28 @@ endif()
 # Configure Transformer Engine library
 include_directories(${PROJECT_SOURCE_DIR}/..)
 set(transformer_engine_SOURCES)
-list(APPEND transformer_engine_SOURCES
+set(transformer_engine_cpp_sources)
+set(transformer_engine_cuda_sources)
+set(transformer_engine_cuda_arch_specific_sources)
+
+list(APPEND transformer_engine_cpp_sources
     cudnn_utils.cpp
     transformer_engine.cpp
+    fused_attn/fused_attn.cpp
+    gemm/config.cpp
+    normalization/common.cpp
+    normalization/layernorm/ln_api.cpp
+    normalization/rmsnorm/rmsnorm_api.cpp
+    util/cuda_driver.cpp
+    util/cuda_nvml.cpp
+    util/cuda_runtime.cpp
+    util/multi_stream.cpp
+    util/rtc.cpp
+    comm_gemm_overlap/userbuffers/ipcsocket.cc
+    comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+    comm_gemm_overlap/comm_gemm_overlap.cpp)
+
+list(APPEND transformer_engine_cuda_sources
     common.cu
     multi_tensor/adam.cu
     multi_tensor/compute_scale.cu
@@ -92,40 +156,23 @@ list(APPEND transformer_engine_SOURCES
     transpose/cast_transpose_fusion.cu
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
-    transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
     transpose/swap_first_dims.cu
-    transpose/quantize_transpose_vector_blockwise_fp4.cu
-    activation/gelu.cu
     dropout/dropout.cu
     fused_attn/flash_attn.cu
     fused_attn/context_parallel.cu
     fused_attn/kv_cache.cu
     fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
-    activation/relu.cu
-    activation/swiglu.cu
     fused_attn/fused_attn_fp8.cu
-    fused_attn/fused_attn.cpp
     fused_attn/utils.cu
-    gemm/config.cpp
     gemm/cublaslt_gemm.cu
-    gemm/cutlass_grouped_gemm.cu
-    normalization/common.cpp
-    normalization/layernorm/ln_api.cpp
     normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
     normalization/layernorm/ln_fwd_cuda_kernel.cu
-    normalization/rmsnorm/rmsnorm_api.cpp
     normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
     normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
     permutation/permutation.cu
-    util/cast.cu
     util/padding.cu
-    util/cuda_driver.cpp
-    util/cuda_nvml.cpp
-    util/cuda_runtime.cpp
-    util/multi_stream.cpp
-    util/rtc.cpp
     swizzle/swizzle.cu
     swizzle/swizzle_block_scaling.cu
     fused_softmax/scaled_masked_softmax.cu
@@ -139,12 +186,58 @@ list(APPEND transformer_engine_SOURCES
     recipe/delayed_scaling.cu
     recipe/fp8_block_scaling.cu
     recipe/nvfp4.cu
+    comm_gemm_overlap/userbuffers/userbuffers.cu)
+
+list(APPEND transformer_engine_cuda_arch_specific_sources
+    gemm/cutlass_grouped_gemm.cu
+    util/cast.cu
+    activation/gelu.cu
+    activation/relu.cu
+    activation/swiglu.cu
+    transpose/quantize_transpose_square_blockwise.cu
+    transpose/quantize_transpose_vector_blockwise_fp4.cu
     hadamard_transform/hadamard_transform.cu
-    hadamard_transform/hadamard_transform_cast_fusion.cu
-    comm_gemm_overlap/userbuffers/ipcsocket.cc
-    comm_gemm_overlap/userbuffers/userbuffers-host.cpp
-    comm_gemm_overlap/userbuffers/userbuffers.cu
-    comm_gemm_overlap/comm_gemm_overlap.cpp)
+    hadamard_transform/hadamard_transform_cast_fusion.cu)
+
+# Compiling the files with the worst compilation time first to hopefully overlap
+# better with the faster-compiling cpp files
+list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
+                                       ${transformer_engine_cuda_sources}
+                                       ${transformer_engine_cpp_sources})
+
+# Set compile options for CUDA sources with generic architectures
+foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
+  set(arch_compile_options)
+  foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
+    list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
+  endforeach()
+  if(arch_compile_options)
+    set_property(
+      SOURCE ${cuda_source}
+      APPEND
+      PROPERTY
+        COMPILE_OPTIONS ${arch_compile_options}
+    )
+  endif()
+endforeach()
+
+# Set compile options for CUDA sources with specific architectures
+foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
+  set(arch_compile_options)
+  foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
+    list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
+  endforeach()
+  if(arch_compile_options)
+    set_property(
+      SOURCE ${cuda_source}
+      APPEND
+      PROPERTY
+        COMPILE_OPTIONS ${arch_compile_options}
+    )
+  endif()
+endforeach()
+
 if (NVTE_WITH_CUBLASMP)
   list(APPEND transformer_engine_SOURCES
@@ -249,7 +342,8 @@ target_include_directories(transformer_engine PRIVATE
                            "${CMAKE_CURRENT_BINARY_DIR}/string_headers")
 
 # Compiler options
-set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
+set(nvte_sources_with_fast_math)
+list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_aligned_causal_masked_softmax.cu
     multi_tensor/adam.cu
@@ -259,18 +353,24 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
     multi_tensor/sgd.cu
     fused_attn/flash_attn.cu
     fused_attn/context_parallel.cu
-    fused_attn/kv_cache.cu
-    PROPERTIES
-    COMPILE_OPTIONS "--use_fast_math")
+    fused_attn/kv_cache.cu)
 
 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
 if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
-  set_source_files_properties(activation/gelu.cu
+  list(APPEND nvte_sources_with_fast_math activation/gelu.cu
                               activation/relu.cu
                               activation/swiglu.cu
-                              util/cast.cu
-                              PROPERTIES
-                              COMPILE_OPTIONS "--use_fast_math")
+                              util/cast.cu)
 endif()
 
+foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
+  set_property(
+    SOURCE ${cuda_source}
+    APPEND
+    PROPERTY
+      COMPILE_OPTIONS "--use_fast_math")
+endforeach()
+
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
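Two details in the hunk above are worth spelling out. First, the fast-math sources are now gathered into nvte_sources_with_fast_math and the flag is attached with set_property(... APPEND PROPERTY COMPILE_OPTIONS ...): unlike set_source_files_properties, which overwrites the COMPILE_OPTIONS property, APPEND preserves the per-architecture --generate-code options attached to the same files earlier (this is the "files overwriting custom compile properties" fix from the commit log). Second, a rough sketch of what --use_fast_math means for these kernels; the kernel below is a hypothetical illustration, not TE code:

// With --use_fast_math, nvcc maps single-precision libm calls to approximate
// hardware intrinsics, roughly as if the __-prefixed forms were written out.
__global__ void sigmoid_kernel(const float *x, float *y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // expf() compiles like __expf() (ex2.approx.f32), and the division is
    // allowed to lower to div.approx.f32 under --use_fast_math.
    y[i] = 1.0f / (1.0f + expf(-x[i]));
  }
}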
@@ -97,7 +97,8 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
 StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
   using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
   result_type output;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     auto output_ptr = reinterpret_cast<uint16_t *>(&output);
     asm volatile( \
       "{\n" \
@@ -109,10 +110,10 @@ StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::A
       : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
         "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
         "r"(rbits[0]), "r"(rbits[1]));
-#else
+  } else {
     NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
                       "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return output;
 }
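This hunk establishes the pattern repeated through the files below: the preprocessor gate #if CUDA_ARCH_HAS_FEATURE_SM10X_ALL becomes an if constexpr over a boolean macro. Both branches still have to parse everywhere, but the false branch is discarded before code generation, so the arch-specific PTX never reaches ptxas on generic targets, and hitting the fallback becomes a loud device-side error at run time rather than a silently missing code path at build time. A condensed sketch, with stand-in definitions so it compiles in isolation (ARCH_HAS_STOCHASTIC_ROUNDING is assumed to expand to 0 or 1 per codegen target; the stand-ins below are not TE's real definitions):

#include <cstdint>

#ifndef ARCH_HAS_STOCHASTIC_ROUNDING
#define ARCH_HAS_STOCHASTIC_ROUNDING 0   // stand-in; TE derives this per arch
#endif
#ifndef NVTE_DEVICE_ERROR
#define NVTE_DEVICE_ERROR(msg) __trap()  // stand-in for the sketch only
#endif

__device__ __forceinline__ uint16_t convert_or_trap(float2 in01, uint32_t rbits) {
  (void)in01; (void)rbits;  // used only by the arch-specific asm in real code
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
  uint16_t out = 0;
  if constexpr (has_rs) {
    // the arch-specific cvt.rs asm from the hunk above would sit here;
    // it is discarded entirely when has_rs is false
  } else {
    NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific.");
  }
  return out;
}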
@@ -264,7 +264,8 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
 __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
     const float2 in01, const float2 in23, const uint32_t rbits) {
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     uint16_t out_4x;
     asm volatile(
         "{\n"
@@ -273,19 +274,20 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro
         : "=h"(out_4x)
         : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
     return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
-#else
+  } else {
     NVTE_DEVICE_ERROR(
-        "FP4 cvt PTX instructions are architecture-specific. "
+        "FP4 cvt.rs PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
     uint16_t dummy = 0;
     return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
 }
 
 __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                       const float2 in23,
                                                                       const uint32_t rbits) {
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY;
+  if constexpr (has_fp4) {
     // NOTE: rbits unused for rn.
     uint32_t out_4x;  // Only need 16 bit. Using 32 bit container for packing.
     asm volatile(
@@ -299,13 +301,13 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa
         : "=r"(out_4x)
         : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
     return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
     uint16_t dummy = 0;
     return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
 }
@@ -15,10 +15,9 @@
 #include <cudaTypedefs.h>
 #include <cuda_runtime.h>
 
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 #include <cuda_fp4.h>
-#endif  // CUDA_VERSION > 12080
-
+#endif  // FP4_TYPE_SUPPORTED
 #include <cfloat>
 
 #include "../common.h"
@@ -30,7 +29,7 @@
 
 namespace transformer_engine {
 
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 namespace nvfp4_transpose {
 
 using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
@@ -152,12 +151,11 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
   return rbits;
 }
 
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
     const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
@@ -185,20 +183,21 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun
         "}"
         : "=h"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }
 
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
                                                                     const float2 scale,
                                                                     const uint32_t rbits) {
-  // NOTE: rbits unused for rn.
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  if constexpr (is_blackwell) {
+    // NOTE: rbits unused for rn.
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
@@ -230,11 +229,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64
         "}"
         : "=r"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }
@@ -252,7 +251,8 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
     const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
@@ -275,11 +275,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_roun
         : "l"(reinterpret_cast<const uint64_t &>(in01)),
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }
@@ -287,9 +287,10 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
                                                                     const float2 in23,
                                                                     const float2 scale,
                                                                     const uint32_t rbits) {
-  // NOTE: rbits unused for rn.
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  if constexpr (is_blackwell) {
+    // NOTE: rbits unused for rn.
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
@@ -316,11 +317,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
         : "l"(reinterpret_cast<const uint64_t &>(in01)),
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
   }
 }
 
-#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
 __global__ void __launch_bounds__(THREADS_NUM)
@@ -1380,18 +1379,13 @@
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
 
 }  // namespace nvfp4_transpose
-#endif  // CUDA_VERSION > 12080
-
-// Compile-time flag to choose kernel variant
-#ifndef USE_2D_NVFP4_KERNEL
-#define USE_2D_NVFP4_KERNEL 0
-#endif
+#endif  // FP4_TYPE_SUPPORTED
 
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           bool use_2d_quantization>
 void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
                               const QuantizationConfig *quant_config, cudaStream_t stream) {
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
   bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
 
   // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
   }););
 #else
   NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
-#endif  // CUDA_VERSION > 12080
+#endif  // FP4_TYPE_SUPPORTED
 }
 
 }  // namespace transformer_engine
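FP4_TYPE_SUPPORTED replaces the raw CUDA_VERSION > 12080 tests previously scattered through this file, so the FP4 gate now has a single name. Its definition is not part of these hunks; presumably it lives in a shared header and reduces to a toolkit-version check along these lines (an assumed reconstruction, not the verbatim TE definition):

#include <cuda.h>  // CUDA_VERSION, e.g. 12080 for CUDA 12.8

#ifndef FP4_TYPE_SUPPORTED
#define FP4_TYPE_SUPPORTED (CUDA_VERSION > 12080)  // hypothetical definition
#endif

#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>  // the __nv_fp4 types ship with CUDA 12.8 and later
#endif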
(The diff for one additional file is collapsed and not shown here.)
@@ -16,6 +16,7 @@
 #endif
 
 #if !defined(__CUDACC_RTC__)
+#include <cassert>
 #include <cstdint>
 #else
 // Importing C++ standard headers is a pain with NVRTC
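The new cassert include backs the device-side error reporting used by the arch-specific fallbacks above: assert() is one of the few standard-library facilities usable inside CUDA kernels, and it is exactly the kind of header NVRTC cannot import, hence the #if !defined(__CUDACC_RTC__) guard around it. For illustration, a device-side error helper in this spirit might look like the following (a hypothetical sketch; TE's actual NVTE_DEVICE_ERROR may differ):

#include <cassert>
#include <cstdio>

// Print a diagnostic from the offending thread, then stop the kernel;
// a device-side assert(false) aborts the launch and reports file/line.
#define DEVICE_ERROR_SKETCH(msg)          \
  do {                                    \
    printf("ERROR: %s\n", (msg));         \
    assert(false && "device-side error"); \
  } while (0)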