Unverified Commit cc92a4b4 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 256 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 384 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
template<int CHUNKS>
__global__
void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){
fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
}else if( num_chunks == 3 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
} else {
assert(false && "Unsupperted number of chunks");
}
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
static_assert(Cta_tile_dv::N == 64);
static_assert(Cta_tile_dv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The shared memory tile to reload Q as fragment b.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dV.
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;
static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(
params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Load S
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Store s * dmask to smem for transpose
smem_s.store(s_regs);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
float s_mat[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
}
}
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
}
}
}
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
}
}
}
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
#pragma unroll
for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < STEPS - 1) {
gmem_q.commit(smem_q);
}
// Make sure we are reading from the correct buffer.
smem_q.move_to_next_read_buffer();
smem_qt.move_to_next_read_buffer();
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
smem_k.load(frag_k[0], 0);
smem_qt.load(frag_qt[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
smem_dv.store(acc_dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
gmem_dv.store(dv_out);
}
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dk =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
static_assert(Cta_tile_dk::N == 64);
static_assert(Cta_tile_dk::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
// using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dK.
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
static_assert(M == Mma_tile_o::MMAS_M);
static_assert(N == Mma_tile_o::MMAS_K);
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
gmem_s.move();
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int l=0;l<STEPS;l++) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Pack dP as Fragment_a
fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1
frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1
frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9
frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T. dQ = dP x dK
#pragma unroll
for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_o::MMAS_K;
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Store dP to smem for transpose
smem_s.store(s_regs);
if(l < STEPS - 1) {
// Load next part of S
gmem_s.load(s_regs, mask);
gmem_s.move();
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
smem_s.load(frag_s);
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dk::MMAS_K;
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if( l < STEPS - 1) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_qt.load(frag_qt[0], 0);
smem_k.load(frag_k[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, typename Params>
inline __device__ void compute_dv_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dv = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
static_assert(Cta_tile_dv::N == 64);
static_assert(Cta_tile_dv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The shared memory tile to reload Q as fragment b.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store dV.
using Gmem_tile_dv = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o,
fmha::BITS_PER_ELEMENT_B,
Cta_tile_p::N, //S,
Cta_tile_p::K, //D,
2*CHUNKS>;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;
static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the chunk.
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
Noloop nl_traits(bidc);
nl_traits.move_all(gmem_q, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(
params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
const int loop = nl_traits.offset_loop_count(l);
if( loop >= binfo.actual_seqlen ) break;
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
smem_s.store(s_regs);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if(l < nl_traits.num_steps_ - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
float s_mat[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
}
}
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
const float d_s= drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);
}
}
}
}
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
}
}
}
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
static_assert(Mma_tile_dv::MMAS_K == 1); // DEBUG
#pragma unroll
for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < nl_traits.num_steps_ - 1) {
gmem_q.commit(smem_q);
}
// Make sure we are reading from the correct buffer.
smem_q.move_to_next_read_buffer();
smem_qt.move_to_next_read_buffer();
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
smem_k.load(frag_k[0], 0);
smem_qt.load(frag_qt[0], 0);
} // Outer loop over the sequence length.
// Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
smem_dv.store(acc_dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dkv_ptr;
dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);
gmem_dv.store(dv_out);
}
template<int CHUNKS, typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dk = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
static_assert(Cta_tile_dk::N == 64);
static_assert(Cta_tile_dk::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = Gmem_tile_dq<Cta_tile_o>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dK.
using Gmem_tile_dk = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o,
fmha::BITS_PER_ELEMENT_B,
Cta_tile_p::N, //S,
Cta_tile_p::K, //D,
2*CHUNKS>;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
// The global memory tile to load dP, stored in S
using Gmem_tile_s = Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose dP.
using Smem_tile_st = Smem_tile_mma_transposed<Cta_tile_p>;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
static_assert(M == Mma_tile_o::MMAS_M);
static_assert(N == Mma_tile_o::MMAS_K);
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt smem_qt(&smem_[0], tidx);
// Allocate the global memory tile loader for dP.
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Allocate the shared memory tile loader for dP.
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
Noloop nl_traits(bidc);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_qt);
// Trigger the loads for K.
gmem_k.load(smem_k);
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_qt);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for(int l=0;l < nl_traits.num_steps_; l++) {
// Pack dP as Fragment_a
fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
frag_p[ni][mi].reg(0) = dst.x;
frag_p[ni][mi].reg(1) = dst.z;
frag_p[ni][mi].reg(2) = dst.y;
frag_p[ni][mi].reg(3) = dst.w;
}
}
smem_s.store(s_regs);
if(l < nl_traits.num_steps_- 1) {
// Load next part of S
gmem_s.move();
gmem_s.load(s_regs, mask);
// Trigger the load for the next Q values.
smem_qt.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_qt);
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T. dQ = dP x dK
#pragma unroll
for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_o::MMAS_K;
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
smem_s.load(frag_s);
static_assert(Mma_tile_dk::MMAS_K == 1); // DEBUG
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dk::MMAS_K;
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_qt);
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_qt.load(frag_qt[0], 0);
smem_k.load(frag_k[0], 0);
}
} // Outer loop over the sequence length.
// Epilogue for dK = dP' * dq. We're fully exposed to this!
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dkv_ptr;
dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN_reload_v.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS,Kernel_traits, true>(params);
}
template<int CHUNKS>
__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
}
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
} else if( num_chunks == 3 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
} else if( num_chunks == 4 ) {
kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4>
: &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
} else {
assert(false && "Unsupported num_chunks");
}
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for K.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Load the mask for that iteration.
mask.load(l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
// Trigger the load for the next Q values.
if(l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if(l < STEPS - 1) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store S/D.
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
Noloop nl_traits(bidc);
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for K.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
// The number of threads per row.
enum { THREADS_PER_ROW = 32 };
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Trigger the load for the next Q values.
if( l < nl_traits.num_steps_- 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Load the mask for that iteration.
mask.load(nl_traits.loop_offset_ + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_q);
__syncthreads();
smem_q.load(frag_q[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[0], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[0];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for K.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
}
enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));
enum { THREADS_PER_ROW = 32 };
const float pinv = 1.f / params.p_dropout;
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
}
// Load the mask for that iteration.
mask.load(outer);
// Convert from the accumulator typ e to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
__syncthreads();
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
// Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of V values.
smem_v.load(frag_v[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Always sync after last iter: shared smem_q and smem_o!
__syncthreads();
// Output the values.
gmem_o.store(out, ii);
}
// same smem as o
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <multihead_attn/philox.h>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfoPadded(const Params &params,
const int bidb,
const int bidh,
const int tidx)
: bidb(bidb), bidh(bidh), h(params.h) {
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
}
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
// The size of the subsequence this CTA is processing
enum { SUBSEQ = SEQLEN / CHUNKS };
static_assert(SUBSEQ * CHUNKS == SEQLEN);
// The number of steps to process the subsequence
enum { NUM_STEPS = SUBSEQ / STEP };
static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);
inline __device__ Noloop_traits(const int bidc)
: loop_offset_(NUM_STEPS * bidc)
, bidc_(bidc) {
}
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_ = NUM_STEPS;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Noloop_traits<3, Cta_tile>{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
static_assert(STEP == 16 && SEQLEN == 512);
inline __device__ Noloop_traits(const int bidc)
: bidc_(bidc)
, num_steps_(bidc < 2 ? 11 : 10)
, loop_offset_(bidc * 11) {
}
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const int loop_offset_;
const uint32_t bidc_;
const int num_steps_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
return *static_cast<const float4 *>(ptr);
}
inline __device__ void stg128(void *ptr, const float4 &data) {
*static_cast<float4 *>(ptr) = data;
}
template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
const void *__restrict__ in,
const int *__restrict__ cu_seqlens,
const int batch_size) {
enum { BYTES_PER_LDG = 16 };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };
// One CTA hidden vector for K and V
enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
// The stride in bytes in dQKV
enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
// The offset in bytes in dQKV to the dKV part for non-interleaved heads
enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };
static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));
// Size in bytes of the input tile
enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };
enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);
union Vec_t {
float4 raw;
T elt[NUM_ELTS];
};
// ZERO-OUT invalid positions in dQKV
const int total = cu_seqlens[batch_size];
if(blockIdx.x >= total){
enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){
stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
}
return;
}
// SETUP
const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
const char *ptr_in = static_cast<const char *>(in) + offset_in;
const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;
// LOAD
Vec_t local_in[CHUNKS][LDGS];
#pragma unroll
for( int c = 0; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
local_in[c][l].raw = ldg128(ptr_in + offset);
}
}
// UNPACK
float acc[LDGS][NUM_ELTS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] = float(local_in[0][l].elt[e]);
}
}
// COMPUTE
#pragma unroll
for( int c = 1; c < CHUNKS; c++ ) {
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
acc[l][e] += float(local_in[c][l].elt[e]);
}
}
}
// PACK
Vec_t local_out[LDGS];
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
#pragma unroll
for( int e = 0; e < NUM_ELTS; e++ ) {
local_out[l].elt[e] = T(acc[l][e]);
}
}
// STORE
#pragma unroll
for( int l = 0; l < LDGS; l++ ) {
const int offset = l * BYTES_PER_CTA;
stg128(ptr_out + offset, local_out[l].raw);
}
}
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream) {
const int blocks = total;
if(hidden_size == 1024){
constexpr int HIDDEN_SIZE = 1024;
constexpr int THREADS = 256;
if( num_chunks == 2 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else if( num_chunks == 3 ) {
fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
} else {
assert(false && "Unsupported num_chunks");
}
}else{
assert(false && "Unsupported hidden_size");
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
if( dtype == DATA_TYPE_FP16 ) {
half x = __float2half_rn( norm );
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_FP32 ) {
alpha = reinterpret_cast<const uint32_t &>( norm );
} else if( dtype == DATA_TYPE_INT32 ) {
int32_t inorm = static_cast<int32_t>( norm );
alpha = reinterpret_cast<const uint32_t &>( inorm );
} else {
assert( false );
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
switch( dtype ) {
case DATA_TYPE_FP32:
return n * 4;
case DATA_TYPE_FP16:
return n * 2;
case DATA_TYPE_INT32:
return n * 4;
case DATA_TYPE_INT8:
return n;
default:
assert( false );
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm.h"
......@@ -26,23 +24,20 @@ static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
dataPtr = allocator.allocate(size);
data = dataPtr.get();
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
~Workspace() = default;
size_t size;
void* data;
c10::DataPtr dataPtr;
};
// Return {y}
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <c10/cuda/CUDACachingAllocator.h>
#include "batch_norm_add_relu.h"
......@@ -27,23 +25,20 @@ static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
dataPtr = allocator.allocate(size);
data = dataPtr.get();
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
~Workspace() = default;
size_t size;
void* data;
c10::DataPtr dataPtr;
};
// Return {y}
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
......
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
......
......@@ -9,8 +9,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
#include <THC/THCGeneral.h>
const int UNROLL = 4;
template <
......@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
......@@ -222,7 +220,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
}
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
}
template <
......@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
......@@ -260,7 +258,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
}
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
}
template <
......@@ -281,7 +279,7 @@ void apex_add_cuda(scalar_t const *inputs,
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
apex_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t,
......@@ -302,7 +300,7 @@ void apex_masked_scale_cuda(scalar_t const *inputs,
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment