/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

// NOTE(review): <fmha/utils.h> is assumed to be the header that declares DivUpConstexpr,
// MinConstexpr, hmul2 and hrelu2 used below -- confirm the include path.
#include <fmha/utils.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time description of a fragment: element type, element count, and the number of
// 32-bit registers / bytes / alignment derived from them. Holds no data itself.
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // default input type
    using Input_type_ = Data_type_;
    // Does it store the array of elements.
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // The number of elements.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    // The size of element in bits.
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
    // The size of byte of a single register.
    static constexpr int BYTES_PER_REG = 4;
    // The size in bits.
    static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
    // The number of registers needed to store the fragment.
    // (DivUpConstexpr is declared elsewhere; presumably a rounding-up integer division.)
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // The size in bytes (as returned by sizeof(Fragment_base<>).
    static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
    // The alignment: the forced value when ALIGNMENT_ > 0, otherwise derived from the
    // fragment size through MinConstexpr(..., 16).
    static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// An array of Base_::NUM_REGS 32-bit words that can be accessed either as raw registers
// (reg) or as elements of Data_type_ / an arbitrary cast type (elt / elt_as).
template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    // NOTE(review): the element width is taken as 8 * sizeof(Data_type_) bits -- confirm.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    // The backing storage: NUM_REGS 32-bit words.
    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements.
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
    }

    // Add another fragment, element by element, using Data_type_'s own operator+=.
    inline __device__ void add(const Fragment &other) {
        // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
        // Also are we doing int addition or __half2 addition?
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }

    // Multiply by another fragment, one 32-bit register at a time through fmha::hmul2.
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
        }
    }

    // Apply fmha::hrelu2 in place to every 32-bit register.
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hrelu2(this->reg(ii));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// The A operand of the MMA: 8 16-bit elements, i.e. the 4 registers read by
// Fragment_accumulator::mma. The Layout parameter is only a tag here.
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// The B operand of the MMA: 8 16-bit elements, i.e. the 4 registers read by
// Fragment_accumulator::mma. The Layout parameter is only a tag here.
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// The accumulator of the MMA: 8 float elements (elements 0..7 are bound to the PTX below).
struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments. `other` only needs to provide elt(int).
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        #pragma unroll
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    // Scale every element in place by `other`.
    inline __device__ void mul_(const float other) {
        #pragma unroll
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) *= other;
        }
    }

    // Do the HMMA.
    //
    // Two m16n8k16 instructions are issued with the same A registers: the first accumulates
    // into elements 0..3 using b.reg(0..1), the second into elements 4..7 using b.reg(2..3).
    // NOTE(review): per the PTX ISA this mma shape needs sm_80 or newer -- confirm build targets.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a,
                               const Fragment_b<Layout_b> &b) {
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
            "    {%0, %1, %2, %3}, \n"
            "    {%4, %5, %6, %7}, \n"
            "    {%8, %9}, \n"
            "    {%0, %1, %2, %3}; \n"
                    : "+f"(  elt(0)), "+f"(  elt(1)), "+f"(  elt(2)), "+f"(  elt(3))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(0)),  "r"(b.reg(1)));
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
            "    {%0, %1, %2, %3}, \n"
            "    {%4, %5, %6, %7}, \n"
            "    {%8, %9}, \n"
            "    {%0, %1, %2, %3}; \n"
                    : "+f"(  elt(4)), "+f"(  elt(5)), "+f"(  elt(6)), "+f"(  elt(7))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(2)),  "r"(b.reg(3)));
    }

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Call clear() on every fragment of an M x N array of fragments.
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            frag[mi][ni].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Primary template: intentionally empty, only specializations provide apply().
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Specialization for float accumulators.
// NOTE(review): `float` is assumed from the Accumulator_type alias in this file -- confirm.
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    // Zero an M x N array of accumulators. The bool argument is accepted and ignored.
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// For every (mi, ni) pair, accumulate a[mi] x b[ni] into acc[mi][ni] through the
// accumulator's own mma() member.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {

    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            acc[mi][ni].mma(a[mi], b[ni]);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time description of a CTA tile: its M/N/K extents and how warps are arranged.
template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {

    // The tile extents.
    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Derives, from a CTA tile, how many warp-level MMAs are needed along each dimension.
template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM.
    // (DivUpConstexpr is declared elsewhere; presumably a rounding-up integer division.)
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
        MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
        MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //     N_PER_WARP = MMAS_N * N_PER_MMA,
    //     K_PER_WARP = MMAS_K * K_PER_MMA;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Storage types of the GEMM operands / output, and the accumulation / epilogue types.
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

// Element widths in bits, derived from the storage types above.
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Plain alias of Cta_tile_ with the same six parameters.
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same tile as Cta_tile_, except for a padded K extent.
// NOTE(review): the K argument is assumed to be Next_power_of_two<K>::VALUE, with
// Next_power_of_two declared outside this file (presumably fmha/utils.h) -- confirm.
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha