Commit c2b62b7f authored by JR_ZZU 🌴

delete original files

parent 2a4864d5
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
template<bool Is_training>
__global__
void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel<true> : &fmha_fprop_fp16_384_64_sm80_kernel<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
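// Usage sketch (illustrative, not part of the original file): the launcher is
// driven in two phases. With configure == true it only fills in the work
// distribution (num_full_heads, main_steps, ..., elts_per_thread) based on the
// occupancy query; with configure == false it performs the actual launch.
static inline void example_run_fmha_384(Launch_params<Fused_multihead_attention_fprop_params> &launch_params) {
    // Phase 1: compute the work split; no kernel is launched yet.
    run_fmha_fp16_384_64_sm80(launch_params, /*configure=*/true);
    // ... size anything that depends on launch_params.elts_per_thread here ...
    // Phase 2: launch the kernel on launch_params.stream.
    run_fmha_fp16_384_64_sm80(launch_params, /*configure=*/false);
}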
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
template<bool Is_training>
__global__
void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int total_heads) {
fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
}
template<bool Is_training>
__global__
void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true> : &fmha_fprop_fp16_512_64_sm80_kernel<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
const int heads_total = launch_params.params.b * launch_params.params.h;
if(configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
launch_params.elts_per_thread = heads_per_cta * elts_per_head;
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
heads_total);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true> : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
if( launch_params.is_nl ) {
run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
} else {
run_fmha_fp16_512_64_sm80_(launch_params, configure);
}
}
/***************************************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
struct Gemm_Q_K_base {
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
using Fragment_q = typename Smem_tile_q::Fragment;
using Fragment_k = typename Smem_tile_k::Fragment;
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
: smem_q(smem_ptr_q, tidx)
, smem_k(smem_ptr_k, tidx) {
}
__device__ inline void load_q() {
smem_q.load(frag_q[0], 0);
}
__device__ inline void reload_q() {
smem_q.load(frag_q[0], 0);
}
Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
Smem_tile_q smem_q;
Smem_tile_k smem_k;
};
template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };
enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE };
enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
// Q | K / V
// | O | SOFTMAX
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
Base::smem_k.load(frag_k[ki], ki);
}
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
}
__device__ inline void reload_k(){
// Noop.
}
Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};
template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };
enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE };
// Q | K/V + O + SOFTMAX
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
+ Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
Base::smem_k.load(frag_k[0], 0);
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
Base::smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
}
__device__ inline void reload_k(){
Base::smem_k.load(frag_k[0], 0);
}
};
template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}
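// Note: the host launchers pass this value both to
// cudaOccupancyMaxActiveBlocksPerMultiprocessor and as the dynamic shared
// memory size of the launch; when it is 48 KB or more they must also raise
// cudaFuncAttributeMaxDynamicSharedMemorySize first (see the run_fmha_*
// launchers).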
template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() ) return;
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
// Wind gmem tiles to the correct position.
for( int it = 0; it < begin; it++ ) {
gmem_q.move();
gmem_s.move();
gmem_o.move();
}
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
// Trigger the loads for K.
gmem_k.load(gemm_q_k.smem_k);
// Trigger the loads for Q.
gmem_q.load(gemm_q_k.smem_q);
// Trigger the loads for V.
gmem_v.load(smem_v);
const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
#pragma unroll
for( int it = 0; it < Gmem_tile_k::LDGS; it++ ) {
gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
}
// Commit the data for Q and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Create the object to do the softmax.
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
if(begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_p);
// Trigger the load for the next Q values.
if( l < steps - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(gemm_q_k.smem_q);
}
// Load the mask for that iteration.
mask.load(begin + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// If we share smem between K and V, V may not be fully read yet, but we are about to write into that smem for the softmax reduction.
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
//softmax.template reduce<fmha::Max_>(p_max);
softmax.reduce_max(p_max);
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.reduce_sum(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the (non-negative) softmax values to distinguish dropped elements from pre-existing zeros.
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
softmax.pack(frag_p);
gmem_s.store(frag_p, mask);
gmem_s.move();
} else {
softmax.pack(frag_p);
}
// Commit the values for Q into shared memory.
if(l < steps - 1) {
gmem_q.commit(gemm_q_k.smem_q);
}
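// The multiply/ReLU pair below "applies" the dropout encoded above: the sign
// bit marks dropped elements, hrelu2() clamps those negative values to zero,
// and scale_dropout rescales the kept values.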
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
gemm_q_k.reload_k();
// Commit the values for Q into shared memory.
if(l < steps - 1) {
gemm_q_k.reload_q();
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
for( int it = 0; it < num_full_heads; it++ ) {
const int bidx = it * gridDim.x + blockIdx.x;
const int bidh = bidx % params.h;
const int bidb = bidx / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
__syncthreads();
}
if( main_group_size == 0 )
return;
const int head_offset = num_full_heads * gridDim.x;
if( blockIdx.x < main_group_size * num_main_groups ) {
// process within heads
const int group = blockIdx.x % num_main_groups;
const int bidx = blockIdx.x / num_main_groups;
const int bidh = (head_offset + bidx) % params.h;
const int bidb = (head_offset + bidx) / params.h;
const int offset = group * main_steps;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
} else {
if(rest_steps == 0 ) return;
// process across heads
const int bidx = blockIdx.x - main_group_size * num_main_groups;
const int offset = num_main_groups * main_steps;
const int total_heads = params.b * params.h;
const int rest_ctas = gridDim.x - main_group_size * num_main_groups;
for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
const int bidh = it % params.h;
const int bidb = it / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
__syncthreads();
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params, const int total_heads) {
const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
for(int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x){
const int bidh = bidx % params.h;
const int bidb = bidx / params.h;
fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
__syncthreads();
}
}
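// Persistent-CTA scheduling: each CTA grid-strides over all batch x head
// pairs and processes the full STEPS row blocks of every head it picks up.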
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <multihead_attn/philox.cuh>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfoPadded(const Params &params,
const int bidb,
const int bidh,
const int tidx)
: bidb(bidb), bidh(bidh), h(params.h) {
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
}
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
};
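// cu_seqlens is the cumulative (prefix-sum) array of per-batch sequence
// lengths, so sum_s is the token offset of batch bidb in the packed
// variable-length layout and actual_seqlen is its length; empty sequences
// make the CTA exit early via stop_early().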
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
template<typename Block_info>
inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
: bidc_(bidc) {
const int seqlen = binfo.actual_seqlen;
const int steps = (seqlen + STEP - 1) / STEP;
const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
const int step_begin = bidc_ * steps_per_chunk;
const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
const int actual_steps = max(0, step_end - step_begin);
loop_offset_ = step_begin;
num_steps_ = actual_steps;
}
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const uint32_t bidc_;
int loop_offset_;
int num_steps_;
};
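// Example (illustrative numbers): with STEP = 16, CHUNKS = 2 and
// actual_seqlen = 80, steps = 5 and steps_per_chunk = 3, so chunk bidc = 0
// covers steps [0, 3) and chunk bidc = 1 covers steps [3, 5). The dK/dV
// partials of chunk c go to slots 2*c and 2*c + 1 (get_idx_dk / get_idx_dv).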
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
const int num_full_heads = heads_total / total_ctas;
const int heads_last_wave = heads_total % total_ctas;
int num_main_groups = 0;
int main_steps = 0;
int rest_steps = 0;
if( heads_last_wave > 0 ) {
// Number of CTA groups that process within heads.
num_main_groups = total_ctas / heads_last_wave;
// Remaining CTAs that process between heads.
const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
if(rest_ctas == 0) {
// We have exactly "num_main_groups" CTAs to process each of the remaining heads.
main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_steps > 0
rest_steps = STEPS_PER_HEAD % main_steps;
} else {
// Ideal number of steps if we could load-balance as evenly as possible.
const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
// Iterations that a "rest" CTA has to do at most.
const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
// Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
main_steps = steps_ideal;
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
const int max_rest_total_steps = rest_steps * max_rest_iters;
if( max_rest_total_steps < main_steps )
break;
}
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
}
}
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
const int elts_per_thread = max_steps * elts_per_thread_per_step;
return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
}
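// Worked example (numbers chosen purely for illustration): with
// STEPS_PER_HEAD = 32 (e.g. seqlen 512, step 16), total_ctas = 108 and
// heads_total = 100, we get num_full_heads = 0, heads_last_wave = 100,
// num_main_groups = 1 and rest_ctas = 8; the loop settles on main_steps = 30
// and rest_steps = 2, i.e. 100 CTAs each cover the first 30 steps of one
// head while the remaining 8 CTAs sweep the last 2 steps of all 100 heads.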
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
return *static_cast<const float4 *>(ptr);
}
inline __device__ void stg128(void *ptr, const float4 &data) {
*static_cast<float4 *>(ptr) = data;
}
template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
const void *__restrict__ in,
const int *__restrict__ cu_seqlens,
const int batch_size) {
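// Each CTA sums the CHUNKS partial dK/dV rows produced by the no-loop split
// into the dK/dV slice of one row of the packed dQKV output; rows beyond the
// total token count have their whole dQKV row zero-filled.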
enum { BYTES_PER_LDG = 16 };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };
// One row holds one hidden vector for K and one for V
enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
// The stride in bytes in dQKV
enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
// The offset in bytes in dQKV to the dKV part for non-interleaved heads
enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };
static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));
// Size in bytes of the input tile
enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };
enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);
union Vec_t {
float4 raw;
T elt[NUM_ELTS];
};
// ZERO-OUT invalid positions in dQKV
const int total = cu_seqlens[batch_size];
if(blockIdx.x >= total){
enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){
stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
}
return;
}
// SETUP
const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
const char *ptr_in = static_cast<const char *>(in) + offset_in;
const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;
// LOAD
Vec_t local_in[CHUNKS][LDGS];
#pragma unroll
for( int c = 0; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
local_in[c][l].raw = ldg128(ptr_in + offset);
}
}
// UNPACK
float acc[LDGS][NUM_ELTS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] = float(local_in[0][l].elt[e]);
}
}
// COMPUTE
#pragma unroll
for( int c = 1; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] += float(local_in[c][l].elt[e]);
}
}
}
// PACK
Vec_t local_out[LDGS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
local_out[l].elt[e] = T(acc[l][e]);
}
}
// STORE
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
const int offset = l * BYTES_PER_CTA;
stg128(ptr_out + offset, local_out[l].raw);
}
}
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream) {
const int blocks = total;
if(hidden_size == 1024){
constexpr int HIDDEN_SIZE = 1024;
constexpr int THREADS = 256;
if( num_chunks == 2 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else if( num_chunks == 3 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else {
assert(false && "Unsupported num_chunks");
}
}else{
assert(false && "Unsupported hidden_size");
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
if( dtype == DATA_TYPE_FP16 ) {
half x = __float2half_rn( norm );
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_FP32 ) {
alpha = reinterpret_cast<const uint32_t &>( norm );
} else if( dtype == DATA_TYPE_INT32 ) {
int32_t inorm = static_cast<int32_t>( norm );
alpha = reinterpret_cast<const uint32_t &>( inorm );
} else {
assert( false );
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
switch( dtype ) {
case DATA_TYPE_FP32:
return n * 4;
case DATA_TYPE_FP16:
return n * 2;
case DATA_TYPE_INT32:
return n * 4;
case DATA_TYPE_INT8:
return n;
default:
assert( false );
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#include <torch/torch.h>
#include <vector>
#include <cstdint>
// CUDA forward declarations
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor);
at::Tensor focal_loss_backward_cuda(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> focal_loss_forward(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor
) {
CHECK_INPUT(cls_output);
CHECK_INPUT(cls_targets_at_level);
CHECK_INPUT(num_positives_sum);
return focal_loss_forward_cuda(
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
smoothing_factor);
}
at::Tensor focal_loss_backward(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum
) {
CHECK_INPUT(grad_output);
CHECK_INPUT(partial_grad);
return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &focal_loss_forward,
"Focal loss calculation forward (CUDA)");
m.def("backward", &focal_loss_backward,
"Focal loss calculation backward (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#define ASSERT_UINT4_ALIGNED(PTR) \
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
template <class T> bool is_aligned(const void *ptr) noexcept {
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignof(T));
}
template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t,
typename accscalar_t, typename outscalar_t>
__global__ void focal_loss_forward_cuda_kernel(
outscalar_t *loss, scalar_t *partial_grad,
const scalar_t *__restrict__ cls_output,
const labelscalar_t *__restrict__ cls_targets_at_level,
const float *__restrict__ num_positives_sum, const int64_t num_examples,
const int64_t num_classes, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
extern __shared__ unsigned char shm[];
accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm);
loss_shm[threadIdx.x] = 0;
accscalar_t loss_acc = 0;
accscalar_t one = accscalar_t(1.0);
accscalar_t K = accscalar_t(2.0);
accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);
accscalar_t nn_norm, np_norm, pn_norm, pp_norm;
// *_norm is used for label smoothing only
if (SMOOTHING) {
nn_norm = one - smoothing_factor / K;
np_norm = smoothing_factor / K;
pn_norm = smoothing_factor - smoothing_factor / K;
pp_norm = one - smoothing_factor + smoothing_factor / K;
}
uint4 p_vec, grad_vec;
// Accumulate loss on each thread
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) {
int64_t idy = i / num_classes;
labelscalar_t y = cls_targets_at_level[idy];
int64_t base_yid = i % num_classes;
int64_t pos_idx = idy * num_classes + y;
p_vec = *(uint4 *)&cls_output[i];
// Skip ignored matches
if (y == -2) {
#pragma unroll
for (int j = 0; j < ILP; j++) {
*((scalar_t *)(&grad_vec) + j) = 0;
}
*(uint4 *)&partial_grad[i] = grad_vec;
continue;
}
#pragma unroll
for (int j = 0; j < ILP; j++) {
// Skip the pad classes
if (base_yid + j >= num_real_classes) {
*((scalar_t *)(&grad_vec) + j) = 0;
continue;
}
accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j));
accscalar_t exp_np = ::exp(-p);
accscalar_t exp_pp = ::exp(p);
accscalar_t sigma = one / (one + exp_np);
accscalar_t logee = (p >= 0) ? exp_np : exp_pp;
accscalar_t addee = (p >= 0) ? 0 : -p;
accscalar_t off_a = addee + ::log(one + logee);
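// off_a is a numerically stable evaluation of log(1 + exp(-p)) = -log(sigma);
// the (p >= 0) selection avoids overflowing exp() for large |p|.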
// Negative matches
accscalar_t base = SMOOTHING ? nn_norm * p : p;
accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;
accscalar_t coeff_f1 = one - alpha;
accscalar_t coeff_f2 = sigma;
accscalar_t coeff_b1 = gamma;
accscalar_t coeff_b2 = one - sigma;
// Positive matches
if (y >= 0 && (i + j == pos_idx)) {
base = SMOOTHING ? pn_norm * p : 0;
off_b = (SMOOTHING ? pp_norm : one) - sigma;
coeff_f1 = alpha;
coeff_f2 = one - sigma;
coeff_b1 = -gamma;
coeff_b2 = sigma;
}
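// Without label smoothing these coefficients reduce to the standard sigmoid
// focal loss: (1 - alpha) * sigma^gamma * -log(1 - sigma) for negative
// matches and alpha * (1 - sigma)^gamma * -log(sigma) for the positive class.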
accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);
accscalar_t coeff_b = coeff_b1 * coeff_b2;
accscalar_t loss_t = coeff_f * (base + off_a);
accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);
// Delay normalizing the partial gradient by num_positives_sum until back
// propagation, because scalar_t reduces precision and focal loss is very
// sensitive to small gradients. There is no overflow concern here since the
// gradient has a smaller range than the input.
loss_acc += loss_t;
*((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
}
// This is not guaranteed to generate a single stg.128; it may end up as two stg.64 stores.
*(uint4 *)&partial_grad[i] = grad_vec;
}
loss_shm[threadIdx.x] = loss_acc;
// Intra-CTA reduction
__syncthreads();
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];
}
__syncthreads();
}
// Inter-CTA reduction
if (threadIdx.x == 0) {
loss_acc = loss_shm[0] * normalizer;
atomicAdd(loss, loss_acc);
}
}
template <int ILP, typename scalar_t, typename accscalar_t,
typename outscalar_t>
__global__ void focal_loss_backward_cuda_kernel(
scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output,
const float *__restrict__ num_positives_sum, const uint64_t numel) {
int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) /
static_cast<accscalar_t>(num_positives_sum[0]);
// The input is required to be padded for the vectorized load, so there is no
// need to check whether the last element of the ILP group goes out of bounds.
if (idx >= numel)
return;
uint4 grad_vec;
grad_vec = *(uint4 *)&partial_grad[idx];
#pragma unroll(ILP)
for (int i = 0; i < ILP; i++) {
auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
grad *= normalizer;
*((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
}
*(uint4 *)&partial_grad[idx] = grad_vec;
}
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
// Checks required for correctness
TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
"Incorrect number of real classes.");
TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
"Invalid label type.");
TORCH_INTERNAL_ASSERT(
(num_positives_sum.numel() == 1) &&
(num_positives_sum.scalar_type() == at::kFloat),
"Expect num_positives_sum to be a float32 tensor with only one element.");
TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
"Mis-matched dimensions between class output and label.");
for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
"Mis-matched shape between class output and label.");
// Checks required for better performance
const int ILP = sizeof(uint4) / cls_output.element_size();
ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
"Pad number of classes first to take advantage of 128 bit load.");
TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");
int64_t num_classes = cls_output.size(-1);
int64_t num_examples = cls_output.numel() / num_classes;
at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));
// Compute the incomplete gradient during fprop, since most of the heavy math
// of bprop is the same as fprop; trading memory for compute helps focal loss.
at::Tensor partial_grad = at::empty_like(cls_output);
// The grid contains 2 CTAs per SM; each CTA loops over the input with a grid
// stride until the last item.
cudaDeviceProp props;
cudaGetDeviceProperties(&props, at::cuda::current_device());
dim3 block(512);
dim3 grid(2 * props.multiProcessorCount);
// Specialize on label smoothing or not to reduce redundant operations
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (smoothing_factor == 0.0f) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
}
AT_CUDA_CHECK(cudaGetLastError());
return {loss, partial_grad};
}
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum) {
// Each thread process ILP elements
const int ILP = sizeof(uint4) / partial_grad.element_size();
dim3 block(512);
dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
partial_grad.scalar_type(), "focal_loss_bprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
<<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
grad_output.data_ptr<outscalar_t>(),
num_positives_sum.data_ptr<float>(),
partial_grad.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return partial_grad;
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm.h"
#include <cuda.h>
#include "compat.h"
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
dataPtr = allocator.allocate(size);
data = dataPtr.get();
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() = default;
size_t size;
void* data;
c10::DataPtr dataPtr;
};
// Return {y}
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Generate a new magic number and use it for synchronization.
int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first few workspace pointers are backed by existing tensors; the
// remaining ones are carved out of a single allocation below.
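// Workspace slots 0 and 1 are backed by the minibatch_mean / minibatch_inv_var
// tensors and slot 2 by the ret_cta tensor (retired-CTA counter); only slots 3
// and up come out of the single buffer allocated below.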
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Run the forward pass; ReLU fusion is controlled by fuse_relu.
bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y.contiguous(memory_format);
}
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu) {
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
auto memory_format = x.suggest_memory_format();
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Allocate output tensor
at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first three workspace slots are filled explicitly below (minibatch stats are unused at
// inference, retired CTAs come from ret_cta); the remaining buffers are carved out of a single allocated workspace.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Launch the inference kernel; fuse_relu selects the fused-ReLU variant.
bn->fwdInference(stream, fuse_relu);
return y.contiguous(memory_format);
}
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
auto memory_format = x.suggest_memory_format();
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Generate a new magic number and use it for synchronization.
int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
dy.contiguous(memory_format).DATA_PTR<at::Half>());
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(),
bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first three workspace slots (minibatch mean, minibatch inv-var, retired CTAs) are filled from
// existing tensors below; the remaining buffers are carved out of a single allocated workspace at 512-byte-aligned offsets.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad.contiguous(memory_format), scale_grad, bias_grad};
}
int nhwc_bn_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
// Max occupancy supported by the code is 2.
return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
// Max occupancy supported by the code is 2.
return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
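// Call-sequence sketch (illustrative only; the tensors and buffers below are assumed to be set up
// by the caller, e.g. the Python bindings): the occupancy helpers are queried once and their result
// is passed into the forward/backward entry points:
//   const int occ = nhwc_bn_fwd_occupancy();
//   auto y = nhwc_bn_fwd_train(x, scale, bias, running_mean, running_inv_var,
//                              minibatch_mean, minibatch_inv_var, ret_cta,
//                              momentum, epsilon, /*fuse_relu=*/true,
//                              my_data, pair_data, pair_data2, pair_data3,
//                              bn_group, magic_tensor, occ, grid_dim_x, coop);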
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include "dnn.h"
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
public:
NhwcBatchNorm() {
name_ = "nhwc_batchnorm";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNorm() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnorm not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream, bool use_relu);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
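// Worked example (illustrative): with n=2, h=4, w=4 and bn_group=2, m_ = 32 and
// m_bn_adjusted = 64, so svar_inv_count_ = 1/64 (biased, 1/nhw across the group) and
// rvar_inv_count_ = 1/63 (unbiased, 1/(nhw-1)).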
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
}
// Sets the pointers for the scale and bias (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note: a max_grid_y of about 160 is enough for known GPUs; 1024 leaves ample headroom.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceed assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
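// i.e. 1024 * 2 * sizeof(int) = 8 KiB is reserved regardless of the actual grid_y in use.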
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 32;
static const int C_ELEMENTS_PER_CTA = 128;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
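// With the defaults above, ELEMENTS_PER_LDG = 128 / 32 = 4 elements per LDG/STG per thread.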
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
//typedef float StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#else
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(1, 2, coop);
else
LAUNCH_BWD_KERNEL(1, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(0, 2, coop);
else
LAUNCH_BWD_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
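// Worked example (illustrative, assuming a 32-thread warp): the reduction buffer is
// 32 * (512/32) * 4 * sizeof(float) = 8192 bytes, and SMEM_SIZE_FWD/SMEM_SIZE_BWD are 0 with the
// current PIXELS_PER_THREAD_IN_SMEM_* settings, so a 64 KiB-smem SM would allow 8 CTAs, which
// std::min() then clips to max_cta_per_sm (the callers pass 2).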
};
const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = std::max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
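// The returned order {mean bytes, variance bytes, retired-CTA bytes, partial sums, partial counts}
// must match the indexing used in setWorkspacePointers() below.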
void NhwcBatchNorm::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 5);
assert(num_workspace_bytes.size() == 5);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
retired_ctas_ = static_cast<int*>(workspace[2]);
partial_sums_ = static_cast<float*>(workspace[3]);
partial_counts_ = static_cast<int*>(workspace[4]);
}
void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = nullptr;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = nullptr;
params->gmem_relu_bitmask = nullptr;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
if (use_relu) {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
} else {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference kernel");
}
}
dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
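// Worked example (illustrative): with m_ = 1024 pixels and PIXELS_PER_CTA_FWD = 16,
// grid_dim.x = 64. With a CTA budget grid_dim_x = 128 the kernel runs with *loop = 1 and
// grid_dim.y = min(c_blks, 128/64 = 2) channel blocks per wave; with a budget of only 32,
// grid_dim.x is clamped to 32 and *loop = div_up(1024, 1*16*32) = 2 outer iterations.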
dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
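// sync_iters is effectively log2(bn_group): bn_group 1 -> 0, 2 -> 1, 4 -> 2, 8 -> 3 exchange rounds.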
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& (bias_ != nullptr || !use_relu)
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
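// Scale the weight gradients by 1/bn_group, presumably so the group-reduced dscale/dbias come out
// averaged rather than summed across the bn_group ranks.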
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm_add_relu.h"
#include <cuda.h>
#include "compat.h"
// FIXME: move the shared helpers (cudaCheckErrors, round_up_to_multiple, Workspace) into a common header.
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
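// e.g. round_up_to_multiple(1000, 512) == 1024; used below to keep each workspace slice 512-byte aligned.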
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
dataPtr = allocator.allocate(size);
data = dataPtr.get();
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() = default;
size_t size;
void* data;
c10::DataPtr dataPtr;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Generate a new magic number and use it for synchronization.
int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
z.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first four workspace slots (minibatch mean, minibatch inv-var, ReLU bitmask, retired CTAs) are
// filled from existing tensors below; the remaining buffers are carved out of a single allocated workspace at 512-byte-aligned offsets.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Launch the forward training kernel (the add + ReLU is always fused in this variant).
bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y.contiguous(memory_format);
}
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Allocate output tensor
at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)): at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
z.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first four workspace slots are filled explicitly below (minibatch stats and bitmask are unused
// at inference, retired CTAs come from ret_cta); the remaining buffers are carved out of a single allocated workspace.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// Launch the inference kernel (the add + ReLU is always fused in this variant).
bn->fwdInference(stream);
return y.contiguous(memory_format);
}
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0);
const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = check_channels_last ? x.size(1) : x.size(3);
// Generate a new magic number and use it for synchronization.
int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, z_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
z_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
dy.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr,
z_grad.contiguous(memory_format).DATA_PTR<at::Half>());
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// The first four workspace slots (minibatch mean, minibatch inv-var, ReLU bitmask, retired CTAs) are
// filled from existing tensors below; the remaining buffers are carved out of a single allocated workspace at 512-byte-aligned offsets.
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad.contiguous(memory_format), z_grad.contiguous(memory_format), scale_grad, bias_grad};
}
int nhwc_bn_addrelu_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
// Max occupancy supported by the code is 2.
return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_addrelu_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
// Max occupancy supported by the code is 2.
return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include "dnn.h"
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
using bitmask_pyt_t = int64_t;
#else
using bitmask_t = unsigned int;
using bitmask_pyt_t = int32_t;
#endif
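// Presumably sized to hold one bit per lane: 64-bit on HIP (64-wide wavefronts on the targeted AMD
// GPUs) versus 32-bit on CUDA (32-wide warps).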
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
public:
NhwcBatchNormAddRelu() {
name_ = "nhwc_batchnormaddrelu";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNormAddRelu() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnormaddrelu not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
addend_ = addend;
dAddend_ = dAddend;
}
// Sets the pointers for the scale and bias (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note: a max_grid_y of about 160 is enough for known GPUs; 1024 leaves ample headroom.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceed assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
void* addend_ = nullptr;
void* dAddend_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
bitmask_t *relu_bitmask_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 32;
static const int C_ELEMENTS_PER_CTA = 128;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_add_relu_func, \
hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#else
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_add_relu_func, \
cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_ADD_RELU_KERNEL
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
};
const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
#ifdef __HIP_PLATFORM_HCC__
int elems_per_group = ((m_ + 3) & ~3) * 2;
#else
int elems_per_group = ((m_ + 31) & ~31) * 2;
#endif
int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(bitmask_t);
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes, bitmask_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNormAddRelu::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 6);
assert(num_workspace_bytes.size() == 6);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
relu_bitmask_ = static_cast<bitmask_t*>(workspace[2]);
retired_ctas_ = static_cast<int*>(workspace[3]);
partial_sums_ = static_cast<float*>(workspace[4]);
partial_counts_ = static_cast<int*>(workspace[5]);
}
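// Illustrative sketch (not part of the original sources): one way a caller could size and wire the
// six workspace buffers expected by setWorkspacePointers(), in the order reported by
// numWorkspaceBytes() above (mean, variance, relu bitmask, retired CTAs, partial sums, partial
// counts). The helper name allocate_bn_workspace and the use of plain cudaMalloc are assumptions
// made for this example; real callers typically pass in framework-managed allocations instead.
inline std::vector<void*> allocate_bn_workspace(NhwcBatchNormAddRelu *bn) {
const std::vector<size_t> sizes = bn->numWorkspaceBytes();
std::vector<void*> ptrs(sizes.size(), nullptr);
for (size_t i = 0; i < sizes.size(); ++i) {
cudaMalloc(&ptrs[i], sizes[i]);  // error checking omitted for brevity
}
bn->setWorkspacePointers(ptrs, sizes);
return ptrs;
}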
void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = relu_bitmask_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = static_cast<uint16_t*>(dAddend_);
params->gmem_relu_bitmask = relu_bitmask_;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
}
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
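// Worked example (hypothetical sizes, not from the original sources): with the constants above,
// PIXELS_PER_CTA_FWD = PIXELS_PER_CTA_BWD = (512/32)*1 = 16 pixels per CTA. For m_ = 50176
// (e.g. NHW = 1x224x224) the ideal grid is div_up(50176, 16) = 3136 CTAs. If the launcher caps
// grid_dim_x at, say, 400 resident CTAs, we take the else-branch: grid.x = 400, grid.y = 1,
// pixels_per_iteration = 1*16*400 = 6400, so *loop = div_up(50176, 6400) = 8 outer loops.
// If instead m_ = 3200, then grid.x = div_up(3200, 16) = 200 <= 400, *loop = 1, and any spare
// CTAs are spread across channel blocks via grid.y (bounded by MAX_GBN_BLOCK_Y).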
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dAddend_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#else
#include <ATen/cuda/CUDAContext.h>
#endif
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
#ifdef __HIP_PLATFORM_HCC__
return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor;
#else
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
#endif
}
}
}
}
#endif
#ifndef DNN_H
#define DNN_H
#ifdef __HIP_PLATFORM_HCC__
#include <miopen/miopen.h>
#define DNN_STATUS_SUCCESS miopenStatusSuccess
#define DNN_DATA_HALF miopenHalf
#define DNN_TENSOR_FORMAT 0
using dnnTensorFormat_t = int;
using dnnDataType_t = miopenDataType_t;
using dnnStatus_t = miopenStatus_t;
using dnnTensorDescriptor_t = miopenTensorDescriptor_t;
#else
#include <cudnn.h>
#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS
#define DNN_DATA_HALF CUDNN_DATA_HALF
#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC
using dnnTensorFormat_t = cudnnTensorFormat_t;
using dnnDataType_t = cudnnDataType_t;
using dnnStatus_t = cudnnStatus_t;
using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;
#endif
#endif // DNN_H
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#ifndef VERSION_GE_1_1
#include "ATen/Type.h"
#endif
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;
int64_t get_buffer_size(
const int bn_sync_steps);
void* get_data_ptr(
const at::Tensor& data);
void* get_remote_data_ptr(
const at::Tensor& handle,
const int64_t offset);
void close_remote_data(
const at::Tensor& handle);
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu);
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon);
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int occupancy,
const int grid_dim_x,
const bool coop);
int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();
int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
m.def("close_remote_data", &close_remote_data, "close_remote_data");
m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");
m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include "compat.h"
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template<>
struct std::hash<cudaIpcMemHandle_t> {
size_t operator() (const cudaIpcMemHandle_t& handle) const {
size_t hash = 0;
uint8_t* ptr = (uint8_t*)&handle;
assert(sizeof(uint8_t) == 1);
for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {
hash += *ptr;
ptr++;
}
return hash;
}
};
template<>
struct std::equal_to<cudaIpcMemHandle_t> {
bool operator() (const cudaIpcMemHandle_t &lhs,
const cudaIpcMemHandle_t &rhs) const {
return (std::memcmp((void*) &lhs,
(void*) &rhs,
sizeof(cudaIpcMemHandle_t)) == 0);
}
};
namespace {
namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space: mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
};
class IpcMemHandleRegistry {
public:
void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
if (registry_.count(handle) == 0) {
registry_.insert(std::make_pair(handle, RegistryEntry()));
registry_[handle].dev_ptr = ipcOpenMem(handle);
}
registry_[handle].ref_count++;
return (((uint8_t*)registry_[handle].dev_ptr) + offset);
}
void releasePtr(const cudaIpcMemHandle_t& handle) {
if (registry_.count(handle) == 0) {
// Nothing was registered for this handle (getPtr was never called), so there is nothing to release.
return;
}
if (--registry_[handle].ref_count == 0) {
ipcCloseMem(registry_[handle].dev_ptr);
registry_.erase(handle);
}
}
struct RegistryEntry {
void* dev_ptr;
int ref_count;
RegistryEntry() : dev_ptr(NULL) , ref_count(0) {}
};
protected:
std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;
void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
void *data;
cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
cudaCheckErrors("ipc init");
return data;
}
void ipcCloseMem(void* dev_ptr) {
cudaIpcCloseMemHandle(dev_ptr);
cudaCheckErrors("ipc close");
}
};
}
static IpcMemHandleRegistry ipc_mem_registry;
int64_t get_buffer_size(const int bn_sync_steps) {
return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));
return ipc_mem_registry.getPtr(my_handle, offset);
}
void close_remote_data(const at::Tensor& handle) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.DATA_PTR<uint8_t>(), sizeof(my_handle));
ipc_mem_registry.releasePtr(my_handle);
}
void* get_data_ptr(
const at::Tensor& data) {
return data.DATA_PTR<uint8_t>();
}
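// Illustrative sketch (not part of the original sources): the producer side that pairs with
// get_remote_data_ptr()/close_remote_data() above. One rank allocates the exchange buffer sized
// by get_buffer_size() and exports a cudaIpcMemHandle_t that peer processes can open. The helper
// name export_sync_buffer is an assumption made for this example.
static at::Tensor export_sync_buffer(void** dev_ptr, int bn_sync_steps) {
const int64_t bytes = get_buffer_size(bn_sync_steps);
cudaMalloc(dev_ptr, bytes);
cudaCheckErrors("ipc buffer alloc");
cudaMemset(*dev_ptr, 0, bytes);  // peers spin on the flag words, so start from a clean state
cudaCheckErrors("ipc buffer memset");
cudaIpcMemHandle_t handle;
cudaIpcGetMemHandle(&handle, *dev_ptr);  // exportable handle for other processes
cudaCheckErrors("ipc export");
// Ship the raw handle bytes to peers (e.g. through a CPU uint8 tensor and torch.distributed).
at::Tensor out = at::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, at::kByte);
memcpy(out.DATA_PTR<uint8_t>(), &handle, sizeof(handle));
return out;
}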
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#endif
#include <stdint.h>
#include <algorithm>
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
#define BITMASK_OFFSET 2
#define ONE_BITMASK 1UL
#else
using bitmask_t = unsigned int;
#define BITMASK_OFFSET 2
#define ONE_BITMASK 1U
#endif
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void syncwarp() {
#ifdef __HIP_PLATFORM_HCC__
__builtin_amdgcn_wave_barrier();
#else
__syncwarp();
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
#ifdef __HIP_PLATFORM_HCC__
return __shfl(var, src_lane);
#else
return __shfl_sync(0xFFFFFFFFU, var, src_lane);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION bitmask_t ballot(int predicate) {
#ifdef __HIP_PLATFORM_HCC__
return __ballot(predicate);
#else
return __ballot_sync(0xFFFFFFFFU, predicate);
#endif
}
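// Note (added for clarity): on HIP a wavefront has 64 lanes, so __ballot() produces a 64-bit
// mask, matching bitmask_t = uint64_t above; on CUDA a warp has 32 lanes and __ballot_sync()
// returns a 32-bit mask, matching bitmask_t = unsigned int.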
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T, int ELEMENTS_PER_LDG >
struct PackedStorage {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
typedef T Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int ELEMENTS_PER_LDG >
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };
typedef int Type;
};
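// Illustrative compile-time checks (not part of the original sources): an LDG of
// ELEMENTS_PER_LDG = 4 fp16 values travels as 2 packed 32-bit ints, while fp32 data stays unpacked.
static_assert(PackedStorage<uint16_t, 4>::PACKED_ELEMENTS_PER_LDG == 2, "4 halves pack into 2 ints");
static_assert(sizeof(PackedStorage<uint16_t, 4>::Type) == sizeof(int), "packed unit is 32-bit");
static_assert(PackedStorage<float, 4>::PACKED_ELEMENTS_PER_LDG == 4, "fp32 is not packed");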
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
// Convert from two f32s to two f16s (mantissa LSB rounds to nearest even)
// (From 64-bit to 32-bit)
half *dst_ = (half *) dst;
#pragma unroll
for (int i = 0; i < N; ++i) {
#ifdef __HIP_PLATFORM_HCC__
dst_[2*i] = __float2half(src[2*i]);
dst_[2*i+1] = __float2half(src[2*i+1]);
#else
uint16_t lo, hi;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
#endif
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
// Convert from two f16s to two f32s (From 32-bit to 64-bit)
#pragma unroll
for (int i = 0; i < N; ++i) {
#ifdef __HIP_PLATFORM_HCC__
half *src_ = (half *) src;
dst[2*i] = __half2float(src_[2*i]);
dst[2*i+1] = __half2float(src_[2*i+1]);
#else
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
#endif
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
dst[0] = __ldg((const int*) gmem);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = __ldg((const int*) gmem);
#else
unsigned int tmp;
asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem));
dst[0] = tmp;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
int2 tmp = __ldg((const int2*) gmem);
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
#ifdef __HIP_PLATFORM_HCC__
int2 tmp = __ldg((const int2*) gmem);
dst[0] = tmp.x;
dst[1] = tmp.y;
#else
int2 tmp;
asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
: "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem));
dst[0] = tmp.x;
dst[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg_stream(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
reinterpret_cast<int*>(gmem)[0] = src[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
#ifdef __HIP_PLATFORM_HCC__
reinterpret_cast<int*>(gmem)[0] = src[0];
#else
unsigned int tmp = src[0];
asm volatile ("st.global.cs.s32 [%0], %1;"
:: "l"((uint *)gmem) , "r"(tmp));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
half *gmem_ = (half *) gmem;
half *src_ = (half *) src;
for (int i = 0; i < 4; i++) {
gmem_[i] = src_[i];
}
#else
reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
half *gmem_ = (half *) gmem;
half *src_ = (half *) src;
for (int i = 0; i < 4; i++) {
gmem_[i] = src_[i];
}
#else
asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};"
:: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1]));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg_stream(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef __HIP_PLATFORM_HCC__
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[4]) {
half *gmem_ = (half *) gmem;
gmem_[0] = __float2half(src[0]);
gmem_[1] = __float2half(src[1]);
gmem_[2] = __float2half(src[2]);
gmem_[3] = __float2half(src[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[4]) {
half *gmem_ = (half *) gmem;
gmem_[0] = __float2half(src[0]);
gmem_[1] = __float2half(src[1]);
gmem_[2] = __float2half(src[2]);
gmem_[3] = __float2half(src[3]);
}
#endif
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = gmem[2*idx];
dst[1] = gmem[2*idx+1];
#else
float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = gmem[4*idx];
dst[1] = gmem[4*idx+1];
dst[2] = gmem[4*idx+2];
dst[3] = gmem[4*idx+3];
#else
float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
dst[2] = tmp.z;
dst[3] = tmp.w;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[2*idx];
x[1] = smem[2*idx+1];
#else
float2 tmp = *(const float2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
x[0] = smem[idx];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[4*idx];
x[1] = smem[4*idx+1];
x[2] = smem[4*idx+2];
x[3] = smem[4*idx+3];
#else
float4 tmp = *(const float4*) &smem[4*idx];
x[0] = tmp.x;
x[1] = tmp.y;
x[2] = tmp.z;
x[3] = tmp.w;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[2*idx];
x[1] = smem[2*idx+1];
#else
int2 tmp = *(const int2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
gmem[2*idx] = src[0];
gmem[2*idx+1] = src[1];
#else
reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
#ifdef __HIP_PLATFORM_HCC__
gmem[4*idx] = src[0];
gmem[4*idx+1] = src[1];
gmem[4*idx+2] = src[2];
gmem[4*idx+3] = src[3];
#else
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
#ifdef __HIP_PLATFORM_HCC__
gmem[4*idx] = src[0]*coeff;
gmem[4*idx+1] = src[1]*coeff;
gmem[4*idx+2] = src[2]*coeff;
gmem[4*idx+3] = src[3]*coeff;
#else
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
#ifdef __HIP_PLATFORM_HCC__
smem[2*idx] = x[0];
smem[2*idx+1] = x[1];
#else
reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
smem[idx] = x[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
#ifdef __HIP_PLATFORM_HCC__
smem[4*idx] = x[0];
smem[4*idx+1] = x[1];
smem[4*idx+2] = x[2];
smem[4*idx+3] = x[3];
#else
reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
#ifdef __HIP_PLATFORM_HCC__
smem[2*idx] = x[0];
smem[2*idx+1] = x[1];
#else
reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0.f;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] += y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= scalar;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],
const float (&scale)[N], const float (&m1)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
}
}
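// Example (illustrative numbers only): callers such as nhwc_batch_norm_fwd_inference below fold
// the variance into `scale` first (scale[i] *= rsqrtf(var[i] + var_eps)), so with x = 3.0,
// mean = 1.0, var = 3.99, var_eps = 0.01, gamma = 2.0 and beta = 0.5 we get
// scale = 2.0 * rsqrt(4.0) = 1.0, and normalize() yields 0.5 + 1.0 * (3.0 - 1.0) = 2.5.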
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
Storage zero = (Storage)0.f;
return (in < zero)? zero : in;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = relu(x[i]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
void* params_my_data, void** params_pair_datas, int off,
const int magic,
const int sync_iters) {
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const int THREADS_PER_WARP = 64;
#else
const int THREADS_PER_WARP = 32;
#endif
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space: mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
// We do the final reduction: each half-warp sequentially reduces the final values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// Make sure the data was read from SMEM.
syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
// probably could do it earlier, before sync
#ifndef __HIP_PLATFORM_HCC__ // bn_group > 1 is not enabled on HIP
for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {
//float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
void* params_pair_data = params_pair_datas[sync_iter];
// skip the space consumed by previous sync iterations
const int xbuf_offset = sync_iter*data_total;
// data starts after flags, but have to skip previous
const int data_offset = xbuf_offset
+ off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2
+ ELEMENTS_PER_LDG*threadIdx.x*2;
// after the sums for this GPU are computed, let CTA0 broadcast the sum to the other GPU
if (blockIdx.x == 0) {
volatile float * write_data =
&((reinterpret_cast<float*>(params_pair_data))[data_offset]);
// write the data to memory region to be reflected to other GPU
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data) , "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data+4) , "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
}
// now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile float * read_data =
&((reinterpret_cast<float*>(params_my_data))[data_offset]);
float other[4];
uint32_t other_flag_a, other_flag_b;
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) : "l"(read_data));
} while ((other_flag_a != magic) || (other_flag_b != magic));
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) : "l"(read_data+4));
} while ((other_flag_a != magic) || (other_flag_b != magic));
add(x, other);
}
#endif
// finally, after syncing up and accounting for partial sums from
// other GPUs as required, write the result
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const int THREADS_PER_WARP = 64;
#else
const int THREADS_PER_WARP = 32;
#endif
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 8;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id);
}
// The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
// We do the final reduction: each half-warp sequentially reduces the final values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id);
}
// Make sure the data was read from SMEM.
syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const int THREADS_PER_WARP = 64;
#else
const int THREADS_PER_WARP = 32;
#endif
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
// total size of data per sync iter
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
// We do the final reduction: each half-warp sequentially reduces the final values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// Make sure the data was read from SMEM.
syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
// probably could do it earlier, before sync
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
struct ParallelSums {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
template<>
struct ParallelSums<16, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
}
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);
}
};
template<>
struct ParallelSums<8, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
}
};
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline int div_up(int m, int n) {
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
// Register the CTA.
if (threadIdx.x == 0) {
// Issue the membar.
__threadfence();
// Notify that the CTA is done.
int val_to_add = 1;
if (master) {
val_to_add = -(expected_count - 1);
}
atomicAdd(gmem_retired_ctas, val_to_add);
}
// Are all CTAs done?
if (threadIdx.x == 0) {
int retired_ctas = -1;
do {
__threadfence();
#ifdef __HIP_PLATFORM_HCC__
retired_ctas = __ldg((const int*) gmem_retired_ctas);
#else
asm volatile ("ld.global.cg.b32 %0, [%1];"
: "=r"(retired_ctas) : "l"(gmem_retired_ctas));
#endif
} while (retired_ctas != 0);
}
__syncthreads();
}
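// Worked example (illustrative numbers, not from the original sources): with expected_count = 4
// CTAs and the counter starting at 0, the three non-master CTAs each add +1 (partial sums 1..3)
// and the master adds -(4 - 1) = -3, so the counter reads 0 only once every CTA has registered.
// Each CTA's thread 0 then spins on the load above until it observes 0 and releases its block at
// the trailing __syncthreads(). Since the counter finishes back at 0 it is ready for the next
// launch as-is, which is why the fwd/dgrad launchers note that resetting retired_cta_count is no
// longer needed.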
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdInferenceParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// the final mean and variance as calculated during the training process
float *gmem_mean, *gmem_var;
// The bias/scale.
float *gmem_bias, *gmem_scale;
// The dimensions.
int nhw, c;
// epsilon
float var_eps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int ELEMENTS_PER_LDG,
bool USE_RELU,
bool USE_ADD_RELU
>
__global__ __launch_bounds__(THREADS_PER_CTA)
void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// The start position in the NHW dimension where the CTA starts.
const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// thread's starting point in NHW
const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;
// The position in the C dimension where the CTA starts.
const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];
float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];
zero_array(mean);
zero_array(var);
zero_array(scale);
zero_array(bias);
if (is_valid_c) {
read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
}
// Update the scale with the stddev and eps.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
scale[i] *= rsqrtf(var[i] + params.var_eps);
}
// The base pointers for reading/writing
uint16_t *const gmem_src = &params.gmem_src[thread_c];
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// apply BN
for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {
float x_math[ELEMENTS_PER_LDG];
zero_array(x_math);
if (is_valid_c) {
ldg(x_math, &gmem_src[nhw*params.c]);
}
// Normalize and apply activation function
normalize(x_math, bias, scale, mean);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
zero_array(x1_math);
// Guard the residual load the same way as the gmem_src load above, so we never read past C.
if (is_valid_c) {
ldg(x1_math, &gmem_src1[nhw*params.c]);
}
add(x_math, x1_math);
relu_activation(x_math);
} else if (USE_RELU) {
relu_activation(x_math);
}
if (is_valid_c) {
stg(&gmem_dst[nhw*params.c], x_math);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// The bias/scale.
float *gmem_bias, *gmem_scale;
// running mean/var (refer BN API from cudnn doc)
float *gmem_running_mean, *gmem_running_var;
// saved mean/var (refer BN API from cudnn doc)
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
bitmask_t *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
float rvar_inv_count;
// The buffer to do the reduction for mean, stddev and count.
float *gmem_sums;
// The buffer to count items in the different CTAs.
int *gmem_counts;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// The epsilon to apply to the computation of the variance.
float var_eps;
// outer loop count
int outer_loops;
// exponential average factor
float exp_avg_factor;
// number of CTAs along .x dimension
int c_blks;
void* my_data;
void* pair_datas[4];
int magic;
int sync_iters;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
bool USE_RELU,
bool USE_ADD_RELU,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef __HIP_PLATFORM_HCC__
const half zero_h = __float2half(0.0F);
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Clamp thread_c so that we load from valid locations even if we don't use the value
if (!is_valid_c)
thread_c = params.c - 4;
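        // (The clamp keeps the loads below in bounds; the hard-coded 4 presumably matches
        //  ELEMENTS_PER_LDG here. The loaded values are never consumed, since is_valid_c
        //  stays false for this thread.)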
// Single pass numerically stable algorithm, see:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//
// n = 0, mean = 0.0, M2 = 0.0
//
// for x in data:
// n += 1
// delta = x - mean
// mean += delta/n
// delta2 = x - mean
// M2 += delta*delta2
//
// if n < 2:
// return float('nan')
// else:
// return M2 / (n - 1)
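        // The loops below apply this update per thread, vectorized over the ELEMENTS_PER_LDG
        // channels of one LDG; out-of-range pixels carry is_valid == 0, so they leave both
        // mean and m2 unchanged.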
// Register to store the number of elements read so far.
float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean[i] = 0.f;
m2[i] = 0.f;
}
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointer to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute the mean/var across those elements.
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS +
PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
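        // Shifting the start back (possibly below zero) lines the end of the register+SMEM
        // span up with params.nhw (to within the 32-pixel rounding); negative indices simply
        // fail the unsigned bounds checks below.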
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) -
max(nhw_regs, 0), 0);
// Load the data and compute the local mean/sum and the variance.
if (USE_ONLINE_APPROACH) {
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
#ifndef __HIP_PLATFORM_HCC__
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
#endif
ldg(x_storage[i], &gmem_src[idx*params.c]);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_valid[i];
}
}
} else {
// Read the elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
}
count += 1.f;
}
}
// Sum the elements in registers.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] += x_math[j];
}
}
// Compute the mean.
float inv_count = 1.f / count;
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] *= inv_count;
}
// Compute the variance.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Is it a valid pixel?
float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]);
// The offset to store in SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_pixel_valid;
}
}
}
// We scale the mean by the number of elements. It brings more stability.
float m1[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = mean[i] * count;
}
        // Run the parallel sum across the CTA to get the local sum.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m1, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Adjust the variance.
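        // (m1 now holds the CTA-wide sum, so m1*inv_cta_count is the CTA mean. Adding
        //  count*(cta_mean - thread_mean)^2 to each thread's local M2 re-centres it on the
        //  CTA mean, and the parallel sum below then yields the CTA-wide M2: the standard
        //  parallel-variance merge.)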
float inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
float mean_diff = m1[i]*inv_cta_count - mean[i];
m2[i] = m2[i] + mean_diff * mean_diff * count;
}
        // Run the parallel sum across the CTA to get the local adjusted variance.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m2, thread_in_cta_nhw);
        // The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, m1);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);
}
// The memory location to store the number of pixels per CTA.
int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];
if (threadIdx.x == 0) {
gmem_counts[blockIdx.x] = cta_count;
}
// Read the bias and scale.
float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG];
if (is_valid_c) {
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
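        // Past this grid-wide sync, every CTA of this c-block has published its partial sums
        // (gmem_sums) and pixel count (gmem_counts), so the reductions below may read them.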
// Reset the mean to compute the global mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = 0.f;
}
// Build the global mean.
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp[ELEMENTS_PER_LDG];
read_from_gmem(tmp, gmem_sums, idx);
add(m1, tmp);
}
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m1, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Normalize the mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = m1[i] * params.svar_inv_count;
}
// Reset the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] = 0.f;
}
// for add+relu fusion
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// Build the global variance.
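        // (Same merge as above, now across CTAs: each CTA's M2 from gmem_sums is re-centred on
        //  the global mean m1 by adding cta_count*(global_mean - cta_mean)^2 before summing.)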
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
// Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];
read_from_gmem(tmp_mean, &gmem_sums[ 0], idx);
read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);
// Read the number of pixels visited by a given CTA.
cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);
// Compute the diff to update the variance.
float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;
}
// Update the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);
}
}
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m2, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
read_from_smem(m2, smem, thread_in_cta_c);
// Finalize the stddev.
        // Because saved var and running var may use different denominators, we don't do it here:
// scale_(m2, inv_count);
// store the saved mean/var
float svarinv[ELEMENTS_PER_LDG];
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);
write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);
}
// store the running mean/var
float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];
zero_array(rmean);
zero_array(rvar);
if (params.exp_avg_factor != 1.f && is_valid_for_saving) {
read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);
read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);
}
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + \
params.exp_avg_factor * m1[i];
rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + \
params.exp_avg_factor * (m2[i] * params.rvar_inv_count);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);
write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);
}
// Update the scale with the stddev and eps.
multiply(scale, svarinv);
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
#ifdef __HIP_PLATFORM_HCC__
((params.nhw + 3) & ~3) * 2 * c_blk_index;
#else
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#endif
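        // Each c-block owns an aligned, rounded-up slice of the ReLU bitmask. The ballot words
        // written below (USE_ADD_RELU path) are re-read with the same indexing by
        // nhwc_batch_norm_bwd_add_relu.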
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
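                    // One warp-wide ballot per channel of this LDG: after the loop, lane j
                    // holds the ballot word for channel j, and lanes 0..ELEMENTS_PER_LDG-1
                    // write those words out so the backward pass can rebuild the ReLU mask
                    // without re-reading the output.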
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
#ifdef __HIP_PLATFORM_HCC__
bool rectified = __hle(__float2half(x_math[j]), zero_h);
#else
bool rectified = x_math[j] < 0;
#endif
bitmask_t local_relu_mask = ballot(rectified);
if (lane_id == j) {
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[j] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
#pragma unroll 2
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Read from SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
#ifdef __HIP_PLATFORM_HCC__
bool rectified = __hle(__float2half(x_math[j]), zero_h);
#else
bool rectified = x_math[j] < 0;
#endif
bitmask_t local_relu_mask = ballot(rectified);
if (lane_id == j) {
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[j] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormBwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
// dscale/dbias
float *gmem_dscale, *gmem_dbias;
// The scale and bias.
float *gmem_scale, *gmem_bias;
// The mean/inv-var saved from fwd pass
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
bitmask_t *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// The buffer to do the reduction for dscale and dbias
float *gmem_sums;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// outer loop count
int outer_loops;
    // The number of channel blocks, i.e. groups of C_ELEMENTS_PER_CTA channels.
int c_blks;
    // Workspace for the optional synchronized reduction path (ParallelSums::dispatchX),
    // used only when sync_iters > 0; pair_datas presumably point at peer buffers.
    void* my_data;
    void* pair_datas[4];
    int magic;
    int sync_iters;
    // Coefficient applied to dscale/dbias when they are written out in the synced path.
    float wgrad_coeff;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if ((y <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if ((y[j] <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (rectified[j] && valid_data) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N],
const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if (y <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (y[j] <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
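// bwd_update maintains dbias as the running mean of dy and dscale as the running mean of
// dy*(x - mean), using the same incremental-mean update as the forward pass; the callers
// later multiply both by the element count to turn them back into sums.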
template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],
const float (&dy)[N], const float (&x)[N],
const float (&mean)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float delta0 = dy[j] - dbias[j];
dbias[j] += delta0 * inv_count;
delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
dscale[j] += delta0 * inv_count;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
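// With the arguments the callers pass in (var = scale*inv_std, dscale pre-multiplied by
// inv_std^2, dbias = sum(dy), inv_count = 1/nhw), this computes the usual batch-norm input
// gradient: dx = (scale/stddev) * (dy - mean(dy) - xhat * mean(dy*xhat)), xhat = (x-mean)/stddev.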
template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],
const float (&var)[N], const float (&x)[N], const float (&mean)[N],
const float (&dscale)[N], const float (&dbias)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float tmp1 = dy[j] - (dbias[j]* inv_count);
float tmp2 = dscale[j] * inv_count;
float tmp3 = x[j] - mean[j];
dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean used for entire duration
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
        // Convert the running means (dbias/dscale) back into sums by scaling with the element
        // count; accumulating means first brings more stability.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
        // The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
        // The inverse stddev saved by the forward pass.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
// Normalize the dscale.
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0)
{
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// scale
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
// Further normalize the dscale to be used in dx calculation
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
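        // At this point var = scale * inv_std, dscale = sum(dy*(x-mean)) * inv_std^2 and
        // dbias = sum(dy), which is exactly the pre-scaling bwd_dx expects.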
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean/var/scale/bias used for the entire duration
// Register usage optimizations:
// 1. Can combine bias - (mean * var * scale) into a single register
// 2. Can combine var * scale into a single register
float varscale[ELEMENTS_PER_LDG];
zero_array(varscale);
if (is_valid_c) {
read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
float tmp[ELEMENTS_PER_LDG];
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(varscale, tmp);
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);
}
float mean_var_scale_bias[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);
}
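        // With varscale = scale*inv_std, the forward output is y = x*varscale +
        // mean_var_scale_bias, so relu_bwd can recompute the ReLU mask from x alone using
        // just these two registers per channel.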
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
        // Convert the running means (dbias/dscale) back into sums by scaling with the element
        // count; accumulating means first brings more stability.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
        // The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0)
{
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
            // We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized, offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x -
params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
#ifdef __HIP_PLATFORM_HCC__
((params.nhw + 3) & ~3) * 2 * c_blk_index;
#else
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#endif
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
bitmask_t relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
if (is_valid_nhw) {
if (is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask[i] = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
}
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
bool rectified[ELEMENTS_PER_LDG];
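                // Decode the forward ReLU mask: broadcast channel j's ballot word from
                // lane j with shfl_sync, then test this lane's bit.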
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((shfl_sync(relu_mask[i], j) &
(ONE_BITMASK << lane_id)) != 0);
}
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, rectified, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
// Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
from_float(dy_storage[i], dy_math);
// dZ for elementwise add
if (is_valid[i]) {
if (loop_i == OUTER_LOOPS - 1) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);
} else {
stg(&gmem_dst1[idx*params.c], dy_storage[i]);
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_pixel_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid_nhw) {
if (is_valid_c) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
}
}
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((shfl_sync(relu_mask, j) &
(ONE_BITMASK << lane_id)) != 0);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, rectified, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
from_float(dy_storage_local, dy_math);
// dZ for elementwise add
if (is_pixel_valid) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);
}
// only store the 'relu-dgrad'ed version!
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
}
}
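// The pixels that did not fit in registers are staged in shared memory: both x and
// the already relu-dgrad'ed dy are kept there so the second (dx) pass below can
// reread them without another trip to global memory.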
// Scale the running means back up by the element count; accumulating means and rescaling here improves numerical stability.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTA.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
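// Cross-CTA reduction: each CTA has written its partial dscale/dbias sums to the
// gmem workspace above; inter_block_sync() below spins on a retired-CTA counter
// until all gridDim.x CTAs have arrived, after which every CTA rereads all the
// partials and accumulates the full sums locally.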
// The counters to count how many CTAs have retired at this point.
// A given CTA uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0)
{
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// Scale the inverse variance by 'scale' as well; it is used in the dx computation below.
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
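// For reference, the standard per-element batch-norm input gradient (with
// gamma = scale and rsigma = 1/sqrt(variance + eps)) is
//   dx = gamma * rsigma * (dy - dbias/count - (x - mean) * rsigma^2 * dscale/count),
// where dbias = sum(dy) and dscale = sum(dy * (x - mean)). bwd_dx() below is
// assumed to evaluate this form using the pre-scaled dscale, dbias and var
// computed above.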
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float y[ELEMENTS_PER_LDG];
zero_array(y);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#include <torch/torch.h>
#include <vector>
#include <cstdint>
void index_mul_2d_float_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
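// index_mul_2d computes out[i, :] = in1[idx1[i], :] * in2[i, :] for i in [0, size),
// where out, in2 (and the gradients w.r.t. them) are [size, fea_dim] tensors, in1 is
// a 2D tensor indexed along dim 0 by the int64 tensor idx1, and all tensors are
// expected to be contiguous CUDA tensors (see the kernels in the .cu file).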
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void index_mul_2d_float_forward(
at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_forward_cuda(out, in1, in2, idx1);
}
void index_mul_2d_float_backward(
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}
void index_mul_2d_float_backward_backward(
at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}
void index_mul_2d_half_forward(
at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_forward_cuda(out, in1, in2, idx1);
}
void index_mul_2d_half_backward(
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}
void index_mul_2d_half_backward_backward(
at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("float_forward", &index_mul_2d_float_forward,
"index mul float calculation forward (CUDA)");
m.def("float_backward", &index_mul_2d_float_backward,
"index mul float calculation backward (CUDA)");
m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward,
"index mul float calculation backward backward (CUDA)");
m.def("half_forward", &index_mul_2d_half_forward,
"index mul half calculation forward (CUDA)");
m.def("half_backward", &index_mul_2d_half_backward,
"index mul half calculation backward (CUDA)");
m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward,
"index mul half calculation backward backward (CUDA)");
}
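// The bindings above are consumed from Python as a torch extension (built, for
// example, with torch.utils.cpp_extension): module.float_forward(out, in1, in2, idx1),
// module.float_backward(...), and so on. The module name is whatever the build
// script assigns to TORCH_EXTENSION_NAME.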
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#ifdef ATEN_ATOMIC_HEADER
#include <ATen/cuda/Atomic.cuh>
#else
#include <THC/THCAtomics.cuh>
#endif
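// Forward kernels. The dim64 variant below is a fast path for fea_dim == 64: each
// row is read and written as 16 float4 vectors, so 16 threads in x cover one row
// and blockDim.y rows are handled per block. The generic kernels instead stride
// over fea_dim with blockDim.x threads per row.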
__global__ void index_mul_2d_float_dim64(
float *out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 res, src1, src2;
src1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
res.x = src1.x * src2.x;
res.y = src1.y * src2.y;
res.z = src1.z * src2.z;
res.w = src1.w * src2.w;
reinterpret_cast<float4 *>(out)[vec_idx2] = res;
}
}
__global__ void index_mul_2d_float(
float *out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
int64_t vec_idx2 = (start_idx * fea_dim);
for (int i = tidx; i < fea_dim; i += stride) {
out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i];
}
}
}
__global__ void index_mul_2d_half(
at::Half *out,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
int64_t vec_idx2 = (start_idx * fea_dim);
for (int i = tidx; i < fea_dim; i += stride) {
out[vec_idx2 + i] = at::Half(static_cast<float>(in1[vec_idx1 + i]) * static_cast<float>(in2[vec_idx2 + i]));
}
}
}
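// Backward kernels. With out = in1[idx1] * in2, the product rule gives
//   grad_in2[i, :]        = grad_out[i, :] * in1[idx1[i], :]
//   grad_in1[idx1[i], :] += grad_out[i, :] * in2[i, :]
// The grad_in1 update uses atomic adds because idx1 may contain repeated indices.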
__global__ void index_mul_2d_grad_float_dim64(
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 src_in1, src_in2, src_grad_out, dst_grad_in2;
src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);
dst_grad_in2.x = src_grad_out.x * src_in1.x;
dst_grad_in2.y = src_grad_out.y * src_in1.y;
dst_grad_in2.z = src_grad_out.z * src_in1.z;
dst_grad_in2.w = src_grad_out.w * src_in1.w;
reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
}
}
__global__ void index_mul_2d_grad_float(
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_in1 = in1[vec_idx1 + i];
float src_in2 = in2[vec_idx2 + i];
float src_grad_out = grad_out[vec_idx2 + i];
grad_in2[vec_idx2 + i] = src_grad_out * src_in1;
gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2);
}
}
}
__global__ void index_mul_2d_grad_half(
at::Half *grad_in1,
at::Half *grad_in2,
const at::Half *grad_out,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1);
gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2));
}
}
}
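// Double-backward kernels. Differentiating the backward pass w.r.t. its outputs,
// with upstream gradients grad_grad_in1 and grad_grad_in2, gives
//   grad_grad_out[i, :]   = grad_grad_in1[idx1[i], :] * in2[i, :] + grad_grad_in2[i, :] * in1[idx1[i], :]
//   grad_in1[idx1[i], :] += grad_grad_in2[i, :] * grad_out[i, :]   (atomic, idx1 may repeat)
//   grad_in2[i, :]        = grad_grad_in1[idx1[i], :] * grad_out[i, :]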
__global__ void index_mul_2d_grad_grad_float_dim64(
float *grad_grad_out,
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *grad_grad_in1,
const float *grad_grad_in2,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out;
float4 dst_grad_grad_out, dst_grad_in2;
src_grad_grad_in1 = reinterpret_cast<const float4 *>(grad_grad_in1)[vec_idx1];
src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src_grad_grad_in2 = reinterpret_cast<const float4 *>(grad_grad_in2)[vec_idx2];
src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;
dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;
dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;
dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;
reinterpret_cast<float4 *>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;
src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);
dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;
dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y;
dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;
dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;
reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
}
}
__global__ void index_mul_2d_grad_grad_float(
float *grad_grad_out,
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *grad_grad_in1,
const float *grad_grad_in2,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i];
float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i];
float src_in1 = in1[vec_idx1 + i];
float src_in2 = in2[vec_idx2 + i];
float src_grad_out = grad_out[vec_idx2 + i];
grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;
grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out;
gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out);
}
}
}
__global__ void index_mul_2d_grad_grad_half(
at::Half *grad_grad_out,
at::Half *grad_in1,
at::Half *grad_in2,
const at::Half *grad_out,
const at::Half *grad_grad_in1,
const at::Half *grad_grad_in2,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[vec_idx1 + i]);
float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[vec_idx2 + i]);
float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1);
grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);
gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out));
}
}
}
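// Host-side launchers. Each launcher maps one input row to one (blockIdx.x,
// threadIdx.y) pair: BLOCK_NUMS = ceil(size / BLOCK_THREADS_DIMY) blocks, with
// threadIdx.x striding over the feature dimension (or over the 16 float4 vectors
// in the dim64 fast path).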
void index_mul_2d_float_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {  // nothing to do for empty input; also avoids a zero-block launch
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_float<<<BLOCK_NUMS, threads, 0, stream>>>(
out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_half<<<BLOCK_NUMS, threads, 0, stream>>>(
out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
index_mul_2d_grad_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
#pragma once
#include <unordered_map>
#include <functional>
#if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// The input is interpreted as a matrix; we normalize across the columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
#if defined(__HIP_PLATFORM_HCC__)
using bf16 = hip_bfloat16;
#else
using bf16 = nv_bfloat16;
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
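// Example: Types2Key<fp16, fp16, fp16, fp32>::Value packs the four type ids as
// 0 | (0 << 2) | (0 << 4) | (2 << 6) = 0x80, so get(1024) produces the registry key
// (0x80ull << 32) | 1024 = 0x8000000400.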
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
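// A kernel translation unit would typically register its launcher with a file-scope
// object, e.g. (hypothetical launcher name):
//   static layer_norm::FwdRegistrar<fp16, fp16, fp16, fp32, 1024> reg_fwd_1024(launch_fwd_1024);
// so the dispatcher can find it in FWD_FUNCS under Types2Key<...>::get(1024).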
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm