/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file ptx.cuh
 *  \brief BW PTX
 */

#ifndef TRANSFORMER_ENGINE_PTX_CUH_
#define TRANSFORMER_ENGINE_PTX_CUH_

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>

namespace transformer_engine {
namespace ptx {

#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar,
                                                          const uint32_t tx_count) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count)
               : "memory");
}

__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
  asm volatile("fence.mbarrier_init.release.cluster;");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
    uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) {
  uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  // Triggers an async copy, i.e. the thread continues until it waits on the mbarrier.
  // Barrier completion condition:
  //  - the leader must arrive (i.e. 1 thread, as set above)
  //  - the TMA hardware subtracts bytes from the expect_tx counter, which must reach zero
  asm volatile(
      "cp.async.bulk.shared::cta.global"
      ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr),
      "l"(src_global_ptr), "r"(size), "r"(mbar_ptr)
      : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
    uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x,
    const uint32_t offset_y, uint64_t *mbar) {
  uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  // Triggers an async copy, i.e. the thread continues until it waits on the mbarrier.
  // Barrier completion condition:
  //  - the leader must arrive (i.e. 1 thread, as set above)
  //  - the TMA hardware subtracts bytes from the expect_tx counter, which must reach zero
  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr),
      "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr)
      : "memory");
}

__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr,
                                                         const uint32_t parity) {
  uint32_t waitComplete;
  asm volatile(
      "{\n\t .reg .pred P_OUT; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P_OUT; \n"
      "}"
      : "=r"(waitComplete)
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
  return static_cast<bool>(waitComplete);
}

__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
  }
}

#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
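// Illustrative sketch (not part of the production API): how the mbarrier wrappers above are
// typically combined into a TMA load loop. One elected thread issues the bulk copy and posts the
// expected byte count; every thread arrives, then spins on the barrier with a phase (parity) bit
// that flips each time the barrier completes. The function name and parameters are assumptions
// made for this example only; `num_bytes` is assumed to be a multiple of 16 bytes.
__device__ __forceinline__ void example_tma_load_loop(const uint64_t *src_global,
                                                      uint64_t *dst_shmem, uint64_t *mbar,
                                                      const uint32_t num_bytes,
                                                      const uint32_t num_iters,
                                                      const bool is_leader) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t parity = 0;  // phase bit of the mbarrier; flips after each completed phase
  for (uint32_t i = 0; i < num_iters; ++i) {
    if (is_leader) {
      // Kick off the asynchronous bulk copy for this iteration's chunk, then tell the
      // barrier how many bytes the TMA engine must deliver before it can complete.
      cp_async_bulk_tensor_1d_global_to_shared(
          dst_shmem, src_global + i * (num_bytes / sizeof(uint64_t)), num_bytes, mbar);
      mbarrier_arrive_expect_tx(mbar, num_bytes);
    } else {
      // Non-leader threads only contribute their arrival count.
      mbarrier_arrive(mbar);
    }
    // All threads block until every thread has arrived and the TMA hardware has drained
    // the expect_tx byte count to zero.
    mbarrier_wait_parity(mbar, parity);
    parity ^= 1;  // the next completion of this barrier uses the opposite phase

    // ... consume the data in dst_shmem here ...

    // Make sure every thread is done reading before the buffer is overwritten next iteration.
    __syncthreads();
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}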
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
                                                                         const uint64_t *src_shmem,
                                                                         const uint32_t size) {
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr),
               "r"(src_shmem_ptr), "r"(size)
               : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y,
    uint64_t *src_shmem) {
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"(
                   tensor_map_ptr),
               "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr)
               : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() {
  asm volatile("cp.async.bulk.wait_group 0;");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
template <int N>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
  asm volatile("cp.async.bulk.wait_group.read 0;");
}

template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
  asm volatile("cp.async.bulk.wait_group.read 0;");
}

template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
  asm volatile("cp.async.bulk.wait_group.read 1;");
}

template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
  asm volatile("cp.async.bulk.wait_group.read 2;");
}

template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
  asm volatile("cp.async.bulk.wait_group.read 4;");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
  asm volatile("cp.async.bulk.commit_group;");
}

// Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }

__device__ __forceinline__ void fence_proxy_async_shared_cta() {
  asm volatile("fence.proxy.async.shared::cta;");
}

#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

}  // namespace ptx

namespace {

template <int num_barriers, int THREADS_PER_BLOCK>
__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initialize barriers. All `blockDim.x * blockDim.y` threads in the block participate.
#pragma unroll
    for (int iter = 0; iter < num_barriers; ++iter) {
      ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK);
    }
    ptx::fence_proxy_async_shared_cta();
  }
  // Syncthreads so the initialized barriers are visible to all threads.
  __syncthreads();
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

template <int num_barriers>
__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Destroy the barriers. This invalidates the memory region of each barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (is_master_thread) {
#pragma unroll
    for (int iter = 0; iter < num_barriers; ++iter) {
      ptx::mbarrier_invalid(&mbar[iter]);
    }
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
                                                  const size_t num_bytes, uint64_t *barrier,
                                                  const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src),
                                                  num_bytes, barrier);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
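// Illustrative sketch (not part of the production API): an end-to-end round trip built from the
// helpers in this header. It loads a chunk into shared memory via TMA, lets the block transform it
// in place, and writes it back with a bulk store. The function name, element type, and chunk size
// are assumptions made for this example only.
template <int THREADS_PER_BLOCK, size_t CHUNK_BYTES>
__forceinline__ __device__ void example_roundtrip_chunk(float *dst_global, const float *src_global,
                                                        const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  __shared__ alignas(128) float buf[CHUNK_BYTES / sizeof(float)];
  __shared__ alignas(8) uint64_t mbar[1];

  // 1. One thread initializes the mbarrier; all threads sync so it is visible.
  initialize_barriers<1, THREADS_PER_BLOCK>(mbar, is_master_thread);

  // 2. The leader kicks off the TMA load and posts the expected byte count; others just arrive.
  copy_1d_to_shared(buf, src_global, CHUNK_BYTES, &mbar[0], is_master_thread);

  // 3. Everyone waits for the transfer to complete (phase 0 of the barrier).
  ptx::mbarrier_wait_parity(&mbar[0], 0);

  // 4. Compute in shared memory through the generic proxy.
  for (size_t i = threadIdx.x; i < CHUNK_BYTES / sizeof(float); i += THREADS_PER_BLOCK) {
    buf[i] *= 2.0f;
  }
  __syncthreads();

  // 5. Make the generic-proxy shared-memory writes visible to the async (TMA) proxy, then let
  //    the leader issue the bulk store and commit it as a bulk-async group.
  ptx::fence_proxy_async_shared_cta();
  if (is_master_thread) {
    ptx::cp_async_bulk_tensor_1d_shared_to_global(reinterpret_cast<uint64_t *>(dst_global),
                                                  reinterpret_cast<const uint64_t *>(buf),
                                                  static_cast<uint32_t>(CHUNK_BYTES));
    ptx::cp_async_bulk_commit_group();
    // Wait until the store has read `buf`, so the shared buffer could be reused safely.
    ptx::cp_async_bulk_wait_group_read<0>();
  }

  // 6. Invalidate the barrier so its shared-memory slot could be reused.
  destroy_barriers<1>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}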
__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X,
                                                  const size_t chunk_Y, const size_t num_bytes,
                                                  uint64_t *barrier, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X,
                                                  chunk_Y, barrier);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
                                                    const size_t chunk_X1, const size_t chunk_Y1,
                                                    void *dst2, const void *src2,
                                                    const size_t chunk_X2, const size_t chunk_Y2,
                                                    const size_t num_bytes, uint64_t *barrier,
                                                    const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copies
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X1,
                                                  chunk_Y1, barrier);
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
                                                  reinterpret_cast<const uint64_t *>(src2),
                                                  chunk_X2, chunk_Y2, barrier);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

__forceinline__ __device__ void copy_2d_to_sharedx3(
    void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2,
    const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3,
    const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier,
    const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copies
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X1,
                                                  chunk_Y1, barrier);
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
                                                  reinterpret_cast<const uint64_t *>(src2),
                                                  chunk_X2, chunk_Y2, barrier);
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst3),
                                                  reinterpret_cast<const uint64_t *>(src3),
                                                  chunk_X3, chunk_Y3, barrier);
    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

}  // namespace

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_PTX_CUH_