Unverified commit eb34783c authored by Przemyslaw Tredak, committed by GitHub

Overhaul the compilation for the arch-specific features (#2279)

* Added sm_120f to the build
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the arch specific handling
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Support for CUDA<12.9
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Moved through the rest of the files
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Common cases
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Remove pure 100 from the list
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* CMake changes (not yet working)
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Do not pass the arch-specific thing from build_tools
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Moved some of the files to arch-specific compilation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix, and also change the order of compilation to hopefully lower the compilation time
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for the files overwriting custom compile properties
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Actually make this whole thing work
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add space to the error message
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Apply suggestions from code review
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Fixes from review
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changing the naming to be more intuitive
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing cassert include for device-side asserts
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
parent 66acb8e9
@@ -257,11 +257,9 @@ def cuda_archs() -> str:
     if archs is None:
         version = cuda_version()
         if version >= (13, 0):
-            archs = "75;80;89;90;100;100a;103a;120"
-        elif version >= (12, 9):
-            archs = "70;80;89;90;100;100a;103a;120"
+            archs = "75;80;89;90;100;120"
         elif version >= (12, 8):
-            archs = "70;80;89;90;100;100a;120"
+            archs = "70;80;89;90;100;120"
         else:
             archs = "70;80;89;90"
     return archs
...
@@ -5,15 +5,6 @@
 cmake_minimum_required(VERSION 3.21)

 # Language options
-if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
-    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
-  elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
-    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
-  else ()
-    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
-  endif()
-endif()
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CUDA_STANDARD 17)
 set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -30,8 +21,62 @@ project(transformer_engine LANGUAGES CUDA CXX)
 # CUDA Toolkit
 find_package(CUDAToolkit REQUIRED)

-if (CUDAToolkit_VERSION VERSION_LESS 12.0)
-  message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
+if (CUDAToolkit_VERSION VERSION_LESS 12.1)
+  message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
 endif()
+
+# Process GPU architectures
+if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+  elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
+  else ()
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
+  endif()
+endif()
+
+# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
+set(NVTE_GENERIC_ARCHS)
+set(NVTE_SPECIFIC_ARCHS)
+
+# Check for architecture 100
+list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
+if(NOT arch_100_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
+  list(APPEND NVTE_GENERIC_ARCHS "100")
+  list(APPEND NVTE_SPECIFIC_ARCHS "100a")
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+    list(APPEND NVTE_SPECIFIC_ARCHS "103a")
+  endif()
+endif()
+
+# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
+list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
+if(NOT arch_101_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
+  list(APPEND NVTE_GENERIC_ARCHS "101")
+  list(APPEND NVTE_SPECIFIC_ARCHS "101a")
+endif()
+
+# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
+list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
+if(NOT arch_110_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
+  list(APPEND NVTE_GENERIC_ARCHS "110")
+  list(APPEND NVTE_SPECIFIC_ARCHS "110f")
+endif()
+
+# Check for architecture 120
+list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
+if(NOT arch_120_index EQUAL -1)
+  list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
+  list(APPEND NVTE_GENERIC_ARCHS "120")
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+    list(APPEND NVTE_SPECIFIC_ARCHS "120f")
+  else()
+    list(APPEND NVTE_SPECIFIC_ARCHS "120a")
+  endif()
+endif()

 # cuDNN frontend API
@@ -78,9 +123,28 @@ endif()
# Configure Transformer Engine library # Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..) include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES) set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES set(transformer_engine_cpp_sources)
set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp cudnn_utils.cpp
transformer_engine.cpp transformer_engine.cpp
fused_attn/fused_attn.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
list(APPEND transformer_engine_cuda_sources
common.cu common.cu
multi_tensor/adam.cu multi_tensor/adam.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
@@ -92,40 +156,23 @@ list(APPEND transformer_engine_SOURCES
transpose/cast_transpose_fusion.cu transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu dropout/dropout.cu
fused_attn/flash_attn.cu fused_attn/flash_attn.cu
fused_attn/context_parallel.cu fused_attn/context_parallel.cu
fused_attn/kv_cache.cu fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu permutation/permutation.cu
util/cast.cu
util/padding.cu util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_masked_softmax.cu
@@ -139,12 +186,58 @@ list(APPEND transformer_engine_SOURCES
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu recipe/fp8_block_scaling.cu
recipe/nvfp4.cu recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/hadamard_transform_cast_fusion.cu)
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp # Compiling the files with the worst compilation time first to hopefully overlap
comm_gemm_overlap/userbuffers/userbuffers.cu # better with the faster-compiling cpp files
comm_gemm_overlap/comm_gemm_overlap.cpp) list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources}
${transformer_engine_cuda_sources}
${transformer_engine_cpp_sources})
# Set compile options for CUDA sources with generic architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
# Set compile options for CUDA sources with specific architectures
foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
set(arch_compile_options)
foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
endforeach()
if(arch_compile_options)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS ${arch_compile_options}
)
endif()
endforeach()
if (NVTE_WITH_CUBLASMP) if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_SOURCES
@@ -249,28 +342,35 @@ target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers") "${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options # Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu set(nvte_sources_with_fast_math)
fused_softmax/scaled_upper_triang_masked_softmax.cu list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
multi_tensor/adam.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/compute_scale.cu multi_tensor/adam.cu
multi_tensor/l2norm.cu multi_tensor/compute_scale.cu
multi_tensor/scale.cu multi_tensor/l2norm.cu
multi_tensor/sgd.cu multi_tensor/scale.cu
fused_attn/flash_attn.cu multi_tensor/sgd.cu
fused_attn/context_parallel.cu fused_attn/flash_attn.cu
fused_attn/kv_cache.cu fused_attn/context_parallel.cu
PROPERTIES fused_attn/kv_cache.cu)
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
util/cast.cu util/cast.cu)
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif() endif()
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
set_property(
SOURCE ${cuda_source}
APPEND
PROPERTY
COMPILE_OPTIONS "--use_fast_math")
endforeach()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
...
@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) { StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>; using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output; result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
auto output_ptr = reinterpret_cast<uint16_t *>(&output); if constexpr (has_rs) {
asm volatile( \ auto output_ptr = reinterpret_cast<uint16_t *>(&output);
"{\n" \ asm volatile( \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ "{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"}" \ "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
: "=h"(output_ptr[0]), "}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1]) "=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1])); "r"(rbits[0]), "r"(rbits[1]));
#else } else {
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."); "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL }
return output; return output;
} }
...
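The hunk above shows the new dispatch idiom: the `#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL` / `#else` preprocessor guard becomes a compile-time `if constexpr` branch, so the fallback error path is always type-checked and the arch-specific PTX is only emitted for matching builds. A minimal sketch of the idiom, assuming only the ARCH_HAS_STOCHASTIC_ROUNDING and NVTE_DEVICE_ERROR helpers from this PR; the function name and exact operand packing below are illustrative, not part of the change:

// Sketch only: compile-time dispatch between an arch-specific PTX path and a device-side error.
__device__ __forceinline__ uint16_t quantize_fp4x4_rs_sketch(float4 v, uint32_t rbits) {
  uint16_t out = 0;
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;  // true only in sm_100a/sm_103a builds
  if constexpr (has_rs) {
    // Four f32 inputs converted to four packed e2m1 values, stochastic rounding driven by rbits;
    // the brace order follows the kernels above.
    asm volatile("cvt.rs.satfinite.e2m1x4.f32 %0, {%4, %3, %2, %1}, %5;"
                 : "=h"(out)
                 : "f"(v.x), "f"(v.y), "f"(v.z), "f"(v.w), "r"(rbits));
  } else {
    // Generic sm_XXX build: fail loudly at run time instead of emitting unsupported PTX.
    NVTE_DEVICE_ERROR("FP4 cvt.rs PTX requires an arch-specific (sm_100a/sm_103a) build.");
  }
  return out;
}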
@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) { const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
uint16_t out_4x; if constexpr (has_rs) {
asm volatile( uint16_t out_4x;
"{\n" asm volatile(
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" "{\n"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
: "=h"(out_4x) "}"
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); : "=h"(out_4x)
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
#else return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt.rs PTX instructions are architecture-specific. "
uint16_t dummy = 0; "Try recompiling with sm_XXXa instead of sm_XXX.");
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); uint16_t dummy = 0;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
} }
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23, const float2 in23,
const uint32_t rbits) { const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY;
// NOTE: rbits unused for rn. if constexpr (has_fp4) {
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. // NOTE: rbits unused for rn.
asm volatile( uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
"{\n" asm volatile(
".reg.b8 f0; \n\t" "{\n"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" ".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); : "=r"(out_4x)
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
#else return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
uint16_t dummy = 0; "Try recompiling with sm_XXXa instead of sm_XXX.");
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); uint16_t dummy = 0;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
}
} }
template <bool kApplyStochasticRounding> template <bool kApplyStochasticRounding>
...
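The `kApplyStochasticRounding` template parameter mentioned at the end of this hunk presumably selects between the two conversion helpers above at compile time. A hedged sketch of such a dispatcher (the actual one follows later in the file and may differ):

// Illustrative dispatcher over the two conversion helpers shown above.
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_sketch(const float2 in01,
                                                                      const float2 in23,
                                                                      const uint32_t rbits) {
  if constexpr (kApplyStochasticRounding) {
    return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
  } else {
    return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);  // rbits unused for round-to-nearest
  }
}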
@@ -15,10 +15,9 @@
 #include <cudaTypedefs.h>
 #include <cuda_runtime.h>

-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 #include <cuda_fp4.h>
-#endif  // CUDA_VERSION > 12080
-
+#endif  // FP4_TYPE_SUPPORTED
 #include <cfloat>

 #include "../common.h"
@@ -30,7 +29,7 @@
 namespace transformer_engine {

-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 namespace nvfp4_transpose {

 using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
return rbits; return rbits;
} }
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) { const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0; uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
asm volatile( if constexpr (has_rs) {
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b16 v0_bf16; \n\t" ".reg.b64 v23; \n\t"
".reg.b16 v1_bf16; \n\t" ".reg.b16 v0_bf16; \n\t"
".reg.b16 v2_bf16; \n\t" ".reg.b16 v1_bf16; \n\t"
".reg.b16 v3_bf16; \n\t" ".reg.b16 v2_bf16; \n\t"
".reg.b32 v0; \n\t" ".reg.b16 v3_bf16; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" ".reg.b32 v3; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t" "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t" "cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t" "cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t" "cvt.f32.bf16 v2, v2_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order "mov.b64 {v3, v2}, v23; \n\t"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
: "=h"(out_4x) "}"
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits)); : "=h"(out_4x)
#else : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x); return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
} }
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale, const float2 scale,
const uint32_t rbits) { const uint32_t rbits) {
// NOTE: rbits unused for rn. constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL if constexpr (is_blackwell) {
asm volatile( // NOTE: rbits unused for rn.
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b16 v0_bf16; \n\t" ".reg.b64 v23; \n\t"
".reg.b16 v1_bf16; \n\t" ".reg.b16 v0_bf16; \n\t"
".reg.b16 v2_bf16; \n\t" ".reg.b16 v1_bf16; \n\t"
".reg.b16 v3_bf16; \n\t" ".reg.b16 v2_bf16; \n\t"
".reg.b32 v0; \n\t" ".reg.b16 v3_bf16; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
".reg.b8 f0; \n\t" ".reg.b32 v3; \n\t"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" ".reg.b8 f1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t" "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t" "cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t" "cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t" "cvt.f32.bf16 v2, v2_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" "mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale))); : "=r"(out_4x)
#else : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0]; return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
} }
@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0; uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
asm volatile( if constexpr (has_rs) {
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b32 v0; \n\t" ".reg.b64 v23; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
"mov.b64 {v0, v1} , %1; \n\t" ".reg.b32 v3; \n\t"
"mov.b64 {v2, v3} , %2; \n\t" "mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order "mov.b64 {v3, v2}, v23; \n\t"
"}" "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
: "=h"(out_4x) "}"
: "l"(reinterpret_cast<const uint64_t &>(in01)), : "=h"(out_4x)
"l"(reinterpret_cast<const uint64_t &>(in23)), : "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits)); "l"(reinterpret_cast<const uint64_t &>(in23)),
#else "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x); return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
} }
@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
const float2 in23, const float2 in23,
const float2 scale, const float2 scale,
const uint32_t rbits) { const uint32_t rbits) {
// NOTE: rbits unused for rn. constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL if constexpr (is_blackwell) {
asm volatile( // NOTE: rbits unused for rn.
"{\n" asm volatile(
".reg.b64 v01; \n\t" "{\n"
".reg.b64 v23; \n\t" ".reg.b64 v01; \n\t"
".reg.b32 v0; \n\t" ".reg.b64 v23; \n\t"
".reg.b32 v1; \n\t" ".reg.b32 v0; \n\t"
".reg.b32 v2; \n\t" ".reg.b32 v1; \n\t"
".reg.b32 v3; \n\t" ".reg.b32 v2; \n\t"
".reg.b8 f0; \n\t" ".reg.b32 v3; \n\t"
".reg.b8 f1; \n\t" ".reg.b8 f0; \n\t"
"mov.b64 {v0, v1} , %1; \n\t" ".reg.b8 f1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t" "mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 v01, {v0, v1}; \n\t" "mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v23, {v2, v3}; \n\t" "mov.b64 v01, {v0, v1}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order "mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t" "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v3, v2}, v23; \n\t" "mov.b64 {v1, v0}, v01; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" "mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t" "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"}" "mov.b32 %0, {f0, f1, f0, f1};\n\t"
: "=r"(out_4x) "}"
: "l"(reinterpret_cast<const uint64_t &>(in01)), : "=r"(out_4x)
"l"(reinterpret_cast<const uint64_t &>(in23)), : "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(in23)),
#else "l"(reinterpret_cast<const uint64_t &>(scale)));
NVTE_DEVICE_ERROR( } else {
"FP4 cvt PTX instructions are architecture-specific. " NVTE_DEVICE_ERROR(
"Try recompiling with sm_XXXa instead of sm_XXX."); "FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL "Try recompiling with sm_XXXa instead of sm_XXX.");
}
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0]; return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
} }
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
   }
 }

-#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
 __global__ void __launch_bounds__(THREADS_NUM)
@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
 }  // namespace nvfp4_transpose
-#endif  // CUDA_VERSION > 12080
+#endif  // FP4_TYPE_SUPPORTED

-// Compile-time flag to choose kernel variant
-#ifndef USE_2D_NVFP4_KERNEL
-#define USE_2D_NVFP4_KERNEL 0
-#endif
-
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           bool use_2d_quantization>
 void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
                               const QuantizationConfig *quant_config, cudaStream_t stream) {
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
   bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;

   // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
   }););
 #else
   NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
-#endif  // CUDA_VERSION > 12080
+#endif  // FP4_TYPE_SUPPORTED
 }

 }  // namespace transformer_engine
...
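FP4_TYPE_SUPPORTED, which replaces the scattered `CUDA_VERSION > 12080` checks above, is presumably a single toolkit-version guard defined in a shared header. A hypothetical definition, for illustration only (the actual macro lives in this PR's common headers and may differ):

// Hypothetical sketch, not the actual definition: mirror the original CUDA_VERSION check.
#ifndef FP4_TYPE_SUPPORTED
#define FP4_TYPE_SUPPORTED (CUDA_VERSION > 12080)  // cuda_fp4.h and the FP4 kernels need CUDA 12.8+
#endif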
@@ -18,44 +18,165 @@
#include <cuda_fp4.h> #include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080 #endif // CUDA_VERSION >= 12080
#include "common/utils.cuh"
namespace transformer_engine { namespace transformer_engine {
namespace ptx { namespace ptx {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) template <int N>
struct ArchSpecific {
constexpr static int id = N * 10;
template <int CurrentArch, int ArchSpecific, int FamilySpecific>
constexpr static bool compatible() {
if constexpr (CurrentArch == id) {
static_assert(ArchSpecific == CurrentArch,
"Compiled for the generic architecture, while utilizing arch-specific "
"features. Please compile for smXXXa architecture instead of smXXX "
"architecture.");
return true;
} else {
return false;
}
}
};
template <int N>
struct FamilySpecific {
constexpr static int id = N * 10;
template <int CurrentArch, int ArchSpecific, int FamilySpecific>
constexpr static bool compatible() {
if constexpr ((CurrentArch / 100) == (id / 100)) {
static_assert(FamilySpecific == CurrentArch,
"Compiled for the generic architecture, while utilizing family-specific "
"features. Please compile for smXXXf architecture instead of smXXX "
"architecture.");
return true;
} else {
return false;
}
}
};
template <int Arch, int ArchSpecific, int FamilySpecific, class T, class... U>
constexpr bool is_supported_arch() {
if constexpr (T::template compatible<Arch, ArchSpecific, FamilySpecific>()) {
return true;
} else if constexpr (sizeof...(U) != 0) {
return is_supported_arch<Arch, ArchSpecific, FamilySpecific, U...>();
} else {
return false;
}
}
#if CUDA_VERSION < 12090
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#define __CUDA_ARCH_SPECIFIC__ 900
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 900
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1000
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1000
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM101_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1010
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1010
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM120_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1200
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1200
#endif
#endif
#ifdef __CUDA_ARCH__
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#else
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = 0;
#endif
#ifdef __CUDA_ARCH_SPECIFIC__
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = __CUDA_ARCH_SPECIFIC__;
#else
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = 0;
#endif
#ifdef __CUDA_ARCH_FAMILY_SPECIFIC__
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = __CUDA_ARCH_FAMILY_SPECIFIC__;
#else
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = 0;
#endif
#define NVTE_CUDA_ARCH_MATCHES(...) \
[&] { \
__NVTE_CURRENT_ARCH__ \
__NVTE_ARCH_SPECIFIC__ \
__NVTE_ARCH_FAMILY_SPECIFIC__ \
return transformer_engine::ptx::is_supported_arch<current_arch, ArchSpecific, FamilySpecific, \
__VA_ARGS__>(); \
}();
#define ARCH_BLACKWELL_FAMILY \
NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>, ptx::FamilySpecific<110>, \
ptx::FamilySpecific<120>)
#define ARCH_HAS_STOCHASTIC_ROUNDING \
NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>)
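Taken together, ArchSpecific/FamilySpecific and NVTE_CUDA_ARCH_MATCHES let device code ask at compile time whether the current target is a matching arch- or family-specific build, and a generic sm_XXX build that reaches such a check trips the static_assert with a recompile hint. A hedged usage sketch; the function below is hypothetical and only meant to show the intended pattern:

// The convenience macros expand to an immediately-invoked lambda, so the result can
// initialize a constexpr bool and feed if constexpr directly (as the kernels in this PR do).
__device__ __forceinline__ void arch_dispatch_sketch() {
  constexpr bool blackwell = ARCH_BLACKWELL_FAMILY;      // Blackwell family-specific builds (e.g. sm_100f/sm_110f/sm_120f)
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;  // sm_100a/sm_103a arch-specific builds
  if constexpr (has_rs) {
    // ...path that uses cvt.rs stochastic-rounding instructions...
  } else if constexpr (blackwell) {
    // ...family-wide Blackwell path...
  } else {
    NVTE_DEVICE_ERROR("Recompile for an sm_XXXa/sm_XXXf target to use this path.");
  }
}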
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { __device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_arrive is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count)
: "memory"); : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { __device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile("fence.mbarrier_init.release.cluster;"); asm volatile("fence.mbarrier_init.release.cluster;");
#else
NVTE_DEVICE_ERROR("fence_mbarrier_init_release_cluster is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster // global -> shared::cluster
__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
// triggers async copy, i.e. the thread continues until wait() on mbarrier // triggers async copy, i.e. the thread continues until wait() on mbarrier
@@ -67,6 +188,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr),
"l"(src_global_ptr), "r"(size), "r"(mbar_ptr) "l"(src_global_ptr), "r"(size), "r"(mbar_ptr)
: "memory"); : "memory");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_global_to_shared is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -74,6 +198,7 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x,
const uint32_t offset_y, uint64_t *mbar) { const uint32_t offset_y, uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
// triggers async copy, i.e. the thread continues until wait() on mbarrier // triggers async copy, i.e. the thread continues until wait() on mbarrier
@@ -85,9 +210,13 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr),
"l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr)
: "memory"); : "memory");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t waitComplete; uint32_t waitComplete;
asm volatile( asm volatile(
"{\n\t .reg .pred P_OUT; \n\t" "{\n\t .reg .pred P_OUT; \n\t"
@@ -98,15 +227,21 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons
: "r"(mbar_ptr), "r"(parity) : "r"(mbar_ptr), "r"(parity)
: "memory"); : "memory");
return static_cast<bool>(waitComplete); return static_cast<bool>(waitComplete);
#else
NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
return true;
} }
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
} }
} #else
NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_EXPONENT_BIAS = 127;
@@ -121,55 +256,53 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS); return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
} }
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint16_t out; uint16_t out;
asm volatile( asm volatile(
"{\n" "{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}" "}"
: "=h"(out) : "=h"(out)
: "f"(val)); : "f"(val));
return *reinterpret_cast<e8m0_t *>(&out); return *reinterpret_cast<e8m0_t *>(&out);
#else } else {
// TODO: nan/inf needs to be set for any value // TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax. // of nan/inf in input not just amax.
if (isnan(val)) { if (isnan(val)) {
return 0xFF; return 0xFF;
} }
if (isinf(val)) { if (isinf(val)) {
return 0xFE; return 0xFE;
} }
if (val == 0.0f) { if (val == 0.0f) {
return 0x00; return 0x00;
} }
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val); uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF; uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite. // Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent; ++exponent;
}
return exponent;
} }
return exponent;
#endif
} }
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global // shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
const uint64_t *src_shmem, const uint64_t *src_shmem,
const uint32_t size) { const uint32_t size) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr),
"r"(src_shmem_ptr), "r"(size) "r"(src_shmem_ptr), "r"(size)
: "memory"); : "memory");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_shared_to_global is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -177,51 +310,93 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y,
uint64_t *src_shmem) { uint64_t *src_shmem) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"(
tensor_map_ptr), tensor_map_ptr),
"r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr)
: "memory"); : "memory");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() { __device__ __forceinline__ void cp_async_bulk_wait_group() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group 0;"); asm volatile("cp.async.bulk.wait_group 0;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
template <size_t W> template <size_t W>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() { __device__ __forceinline__ void cp_async_bulk_wait_group_read() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group.read 0;"); asm volatile("cp.async.bulk.wait_group.read 0;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
template <> template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group.read 0;"); asm volatile("cp.async.bulk.wait_group.read 0;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
template <> template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group.read 1;"); asm volatile("cp.async.bulk.wait_group.read 1;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
template <> template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group.read 2;"); asm volatile("cp.async.bulk.wait_group.read 2;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
template <> template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.wait_group.read 4;"); asm volatile("cp.async.bulk.wait_group.read 4;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() { __device__ __forceinline__ void cp_async_bulk_commit_group() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("cp.async.bulk.commit_group;"); asm volatile("cp.async.bulk.commit_group;");
#else
NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
// Proxy fence (bi-directional): // Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } __device__ __forceinline__ void fence_proxy_async() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.proxy.async;");
#else
NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
__device__ __forceinline__ void fence_proxy_async_shared_cta() { __device__ __forceinline__ void fence_proxy_async_shared_cta() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.proxy.async.shared::cta;"); asm volatile("fence.proxy.async.shared::cta;");
#else
NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} }
template <typename T> template <typename T>
@@ -282,15 +457,6 @@ static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2); static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080 #endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type. // When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, // When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value // and the converted values are packed in the destination operand d such that the value
...@@ -313,6 +479,7 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons ...@@ -313,6 +479,7 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons
// SIMD like "Fused" cast + multiplication (x2) // SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) { const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair; \n\t" ".reg.b64 val_pair; \n\t"
...@@ -325,10 +492,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, ...@@ -325,10 +492,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)), : "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
const floatx2 &scale) { const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair; \n\t" ".reg.b64 val_pair; \n\t"
...@@ -341,9 +512,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, ...@@ -341,9 +512,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)), : "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair_before; \n\t" ".reg.b64 val_pair_before; \n\t"
...@@ -363,9 +538,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con ...@@ -363,9 +538,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)), : "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair_before; \n\t" ".reg.b64 val_pair_before; \n\t"
...@@ -385,9 +564,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con ...@@ -385,9 +564,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)), : "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair_before; \n\t" ".reg.b64 val_pair_before; \n\t"
...@@ -407,9 +590,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con ...@@ -407,9 +590,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)), : "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile( asm volatile(
"{\n" "{\n"
".reg.b64 val_pair_before; \n\t" ".reg.b64 val_pair_before; \n\t"
...@@ -429,24 +616,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con ...@@ -429,24 +616,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con
: "=h"(reinterpret_cast<uint16_t &>(out)) : "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)), : "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale))); "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
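All of the mul_cvt_2x overloads share the same contract: multiply a packed pair of inputs by a packed pair of FP32 scales and convert the results to a packed pair of FP8 values. A minimal usage sketch, assuming floatx2 is an aggregate of two floats as its use above suggests (namespace qualification abbreviated):

__device__ void example_scaled_cast(fp8e4m3x2 &out, float a, float b, float scale) {
  const floatx2 in = {a, b};
  const floatx2 s = {scale, scale};
  // Each lane is scaled and converted: out packs e4m3(a * scale) and e4m3(b * scale).
  ptx::mul_cvt_2x(out, in, s);
}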
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { __device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst)) : "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)), : "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2))); "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
} }
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst)) : "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)), : "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2))); "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
} }
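The abs_max_2x helpers map to max.xorsign.abs, which keeps the larger magnitude in each 16-bit lane while XOR-ing the sign bits. A hedged sketch of accumulating a packed absolute maximum over a small register tile (the loop structure is illustrative, not taken from this diff):

__device__ void example_amax(bf16x2 &amax, const bf16x2 *vals, int n) {
  for (int i = 0; i < n; ++i) {
    // Each lane of amax keeps the larger magnitude of {amax, vals[i]}.
    // Take the absolute value of the final result if a pure amax is required,
    // since the sign bit becomes the XOR of the input signs.
    ptx::abs_max_2x(amax, amax, vals[i]);
  }
}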
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx } // namespace ptx
namespace { namespace {
...@@ -464,6 +660,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i ...@@ -464,6 +660,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i
} }
// Syncthreads so initialized barrier is visible to all threads. // Syncthreads so initialized barrier is visible to all threads.
__syncthreads(); __syncthreads();
#else
NVTE_DEVICE_ERROR("initialize_barriers is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
...@@ -479,6 +677,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m ...@@ -479,6 +677,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m
ptx::mbarrier_invalid(&mbar[iter]); ptx::mbarrier_invalid(&mbar[iter]);
} }
} }
#else
NVTE_DEVICE_ERROR("destroy_barriers is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
...@@ -498,6 +698,8 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, ...@@ -498,6 +698,8 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
// Other threads just arrive // Other threads just arrive
ptx::mbarrier_arrive(barrier); ptx::mbarrier_arrive(barrier);
} }
#else
NVTE_DEVICE_ERROR("copy_1d_to_shared is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
...@@ -517,6 +719,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co ...@@ -517,6 +719,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co
// Other threads just arrive // Other threads just arrive
ptx::mbarrier_arrive(barrier); ptx::mbarrier_arrive(barrier);
} }
#else
NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
...@@ -543,6 +747,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, ...@@ -543,6 +747,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
// Other threads just arrive // Other threads just arrive
ptx::mbarrier_arrive(barrier); ptx::mbarrier_arrive(barrier);
} }
#else
NVTE_DEVICE_ERROR("copy_2d_to_sharedx2 is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
...@@ -572,6 +778,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3( ...@@ -572,6 +778,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3(
// Other threads just arrive // Other threads just arrive
ptx::mbarrier_arrive(barrier); ptx::mbarrier_arrive(barrier);
} }
#else
NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
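Taken together, the helpers in this anonymous namespace implement one lifecycle: a master thread initializes the mbarriers, a master thread issues each TMA copy while the remaining threads merely arrive, and the barriers are invalidated at the end. The outline below is only a sketch of that lifecycle; argument lists are elided where this diff truncates them, so the commented calls are placeholders rather than the real signatures.

__global__ void example_consumer_kernel() {
  __shared__ uint64_t mbar[1];
  const bool is_master = (threadIdx.x == 0);
  // initialize_barriers(mbar, is_master, ...);    // master inits, then __syncthreads()
  // copy_2d_to_shared(smem_dst, gmem_src, ..., mbar, ...);  // master issues the TMA copy,
  //                                                         // all other threads just arrive
  // ... wait on mbar, then compute on the shared-memory tile ...
  // destroy_barriers(mbar, is_master, ...);       // master invalidates the mbarriers
  (void)mbar;
  (void)is_master;
}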
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#endif #endif
#if !defined(__CUDACC_RTC__) #if !defined(__CUDACC_RTC__)
#include <cassert>
#include <cstdint> #include <cstdint>
#else #else
// Importing C++ standard headers is a pain with NVRTC // Importing C++ standard headers is a pain with NVRTC
......
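The new <cassert> include supports the device-side error reporting used throughout the guarded branches above. The snippet below is purely hypothetical (the real NVTE_DEVICE_ERROR definition is not shown in this diff) and only illustrates why an assert-based device error macro needs <cassert> even in CUDA device code:

#include <cassert>
// HYPOTHETICAL_DEVICE_ERROR is an invented name; it is not the Transformer Engine macro.
#define HYPOTHETICAL_DEVICE_ERROR(message) assert(false && (message))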