Commit de19de7a authored by Tri Dao

Implement for bf16

parent 6a77a6da
......@@ -31,12 +31,12 @@ Our tentative roadmap:
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
3. [Jun 2022] Refactor to use Cutlass.
4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
5. [Jun 2022] Support bf16.
5. ~~[Jun 2022] Support bf16~~[Done].
6. ~~[Jul 2022] Implement cross-attention~~[Done].
7. ~~[Jul 2022] Support head dimension 128~~[Done].
8. [Jul 2022] Support SM70 GPUs (V100).
9. [Aug 2022] Fuse rotary embedding.
10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
## Speedup and Memory Savings
......
......@@ -56,11 +56,13 @@ void set_params_fprop(FMHA_fprop_params &params,
bool is_causal) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = DATA_TYPE_FP16;
Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
// Reset the parameters
memset(&params, 0, sizeof(params));
params.is_bf16 = q.dtype() == torch::kBFloat16;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
......@@ -192,9 +194,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
bool is_dropout = p_dropout > 0.0;
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
TORCH_CHECK(q.dtype() == torch::kFloat16);
TORCH_CHECK(k.dtype() == torch::kFloat16);
TORCH_CHECK(v.dtype() == torch::kFloat16);
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
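In short, the forward and backward entry points now accept fp16 or bf16, with bf16 gated on SM8x devices, and the chosen dtype is recorded once in `params.is_bf16`. A minimal sketch of that dispatch, assuming `is_sm8x` is derived from the device properties the same way as in the launchers; the helper name is hypothetical and assumes the `Data_type` enum from the fmha utility header is in scope:

```cpp
#include <torch/extension.h>

// Hypothetical helper condensing the checks above; not the exact mha_fwd / mha_bwd code.
static Data_type infer_data_type(const at::Tensor &q, bool is_sm8x) {
    auto q_dtype = q.dtype();
    // bf16 is only allowed on SM8x (Ampere-class GPUs); fp16 is allowed on all supported devices.
    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
    return q_dtype == torch::kBFloat16 ? DATA_TYPE_BF16 : DATA_TYPE_FP16;
}
```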
......@@ -326,14 +329,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(q.dtype() == torch::kFloat16);
TORCH_CHECK(k.dtype() == torch::kFloat16);
TORCH_CHECK(v.dtype() == torch::kFloat16);
TORCH_CHECK(out.dtype() == torch::kFloat16);
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(dq.dtype() == torch::kFloat16);
TORCH_CHECK(dk.dtype() == torch::kFloat16);
TORCH_CHECK(dv.dtype() == torch::kFloat16);
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(out.dtype() == q_dtype);
TORCH_CHECK(dout.dtype() == q_dtype);
TORCH_CHECK(dq.dtype() == q_dtype);
TORCH_CHECK(dk.dtype() == q_dtype);
TORCH_CHECK(dv.dtype() == q_dtype);
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
......
......@@ -123,6 +123,7 @@ struct FMHA_fprop_params : public Qkv_params {
// Random state.
at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
......
......@@ -25,11 +25,13 @@
*
******************************************************************************/
#include <cuda_fp16.h>
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
......@@ -80,6 +82,8 @@ struct FMHA_kernel_traits {
// The shared memory tile to store dp sum.
using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;
using elem_type = elem_type_;
// Make sure the number of threads match.
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
......
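The new trailing `elem_type_` parameter defaults to `__half`, so every pre-existing fp16 instantiation keeps compiling unchanged; a bf16 kernel is obtained by passing `__nv_bfloat16` explicitly, exactly as the launchers below do. A minimal sketch, with tile sizes borrowed from one of the head-dimension-64 cases in this commit:

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

// fp16 traits: identical to the old 6-argument form, since elem_type_ defaults to __half.
using Traits_fp16 = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
// bf16 traits: same tile shape, only the element type changes.
using Traits_bf16 = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, __nv_bfloat16>;
static_assert(std::is_same<Traits_bf16::elem_type, __nv_bfloat16>::value, "");
```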
/* Copyright (c) 2022, Tri Dao.
*/
#include "static_switch.h"
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
......@@ -22,24 +23,21 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
bool is_causal = params.is_causal;
auto kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
Kernel_traits, IsDropoutConst, IsCausalConst>;
if (params.seqlen_k == blocksize_c) {
kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
}
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
......@@ -47,63 +45,66 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
});
}
void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
auto dprops = at::cuda::getCurrentDeviceProperties();
if (params.d == 16) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else {
// TD [2022-05-15] 512 gives wrong results rn
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} else if (params.d == 32) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.seqlen_k >= 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} else if (params.d == 64) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.seqlen_k >= 256 ) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor == 0) {
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (dprops->major == 7 && dprops->minor == 5) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}
} else if (params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
// if (params.d == 64) {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 7 && dprops->minor == 5) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else {
// if( params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if( params.seqlen_k >= 256 ) {
// if (dprops->major == 8 && dprops->minor == 0) {
......@@ -111,17 +112,18 @@ void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stre
// // This speeds things up by 2-3% by avoiding register spills, but it
// // uses more shared memory, which is fine on A100 but not other GPUs.
// // For other GPUs, we keep V in registers.
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if (dprops->major == 8 && dprops->minor > 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
// }
// }
// }
// if (params.d == 128) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
});
}
\ No newline at end of file
......@@ -35,6 +35,14 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first,
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
const int loop_step_idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using elem_type = typename Kernel_traits::elem_type;
#else
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
assert(is_fp16_type);
using elem_type = __half;
#endif
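// Note on the guard above: bf16 tensor-core MMA instructions require compute capability 8.0+,
// so when this translation unit is also compiled for older architectures, elem_type falls back
// to __half there (the assert documents that only fp16 kernel traits may reach that path).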
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
......@@ -106,7 +114,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false, elem_type>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
......@@ -214,7 +222,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
......@@ -333,7 +341,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.template pack<__half>(frag_p);
softmax.template pack<elem_type>(frag_p);
// Store s * dmask to smem for transpose
smem_s.store(frag_p);
......@@ -369,9 +377,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_do.load(frag_do[ki & 1], ki);
if (!Kernel_traits::V_IN_REGS) {
smem_v.load(frag_v[ki & 1], ki);
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
......@@ -385,9 +393,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
{
int ki = Mma_tile_p::MMAS_K;
if (!Kernel_traits::V_IN_REGS) {
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
}
......@@ -424,7 +432,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
}
}
softmax.template pack<__half>(frag_p);
softmax.template pack<elem_type>(frag_p);
// Store dp to smem for transpose
smem_dp.store(frag_p);
......@@ -442,14 +450,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert(Gmem_tile_dq::LOOPS == 1);
......@@ -475,7 +483,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
frag_s[ki][mi].template hrelu_<__half>();
frag_s[ki][mi].template hrelu_<elem_type>();
}
}
}
......@@ -485,13 +493,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_dot.load(frag_dot[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
......@@ -519,7 +527,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
if(l < steps - 1) {
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
......@@ -542,13 +550,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Make sure dQ is in shared memory.
......@@ -575,7 +583,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
}
// Output the values.
gmem_dq.template store<__half>(dq_out, 0);
gmem_dq.template store<elem_type>(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move();
} else {
......@@ -629,11 +637,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
smem_dv.template store<__half>(acc_dv);
smem_dv.template store<elem_type>(acc_dv);
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
smem_dk.template store<__half>(acc_dk);
smem_dk.template store<elem_type>(acc_dk);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
......
......@@ -25,6 +25,10 @@
*
******************************************************************************/
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "static_switch.h"
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
......@@ -36,26 +40,8 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
bool is_causal = launch_params.params.is_causal;
// TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
auto kernel = launch_params.is_dropout
? (is_causal
? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
: (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
: (is_causal
? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
: (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
if (configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
......@@ -68,117 +54,133 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
return;
}
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(launch_params.params.b, launch_params.params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
});
});
}
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
auto dprops = at::cuda::getCurrentDeviceProperties();
if (launch_params.params.d == 16) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
// TD [2022-05-15] 512 gives wrong results rn
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 32) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 64) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k >= 256 ) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor >= 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
if (launch_params.is_dropout) { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
}
} else if (launch_params.params.d == 128) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
// TD [2022-06-05] Keep K in registers to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
}
// if (launch_params.params.d == 64) {
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
// // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
// // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// if (launch_params.params.d == 64) {
// if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if( launch_params.params.seqlen_k >= 256 ) {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if (dprops->major == 7 && dprops->minor == 5) {
// // if (launch_params.is_dropout) { // Need to use the same block size as backward
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// // } else {
// // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// // }
// if (launch_params.is_dropout) { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// }
// }
// }
// if (launch_params.params.d == 128) {
// if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
// // TD [2022-06-05] Keep K in registers to reduce register spilling
// // Gives about 6% speedup compared to using block size 128.
// using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
// using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// }
// }
});
}
\ No newline at end of file
......@@ -72,7 +72,7 @@ struct Gemm_Q_K_base {
Smem_tile_k smem_k;
};
template<typename Kernel_traits, bool K_in_regs>
template<typename Kernel_traits, bool K_in_regs, typename elem_type_=__half>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
......@@ -81,6 +81,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Smem_tile_k = typename Base::Smem_tile_k;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
using elem_type = elem_type_;
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
// If V is stored in shared memory, we can't load K using the same shared memory.
......@@ -115,12 +116,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
}
......@@ -132,8 +133,8 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
};
template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
template<typename Kernel_traits, typename elem_type_>
struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
......@@ -141,6 +142,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
using elem_type = elem_type_;
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
......@@ -175,12 +177,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
Base::smem_q.load(Base::frag_q[ki & 1], ki);
Base::smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
}
......@@ -197,6 +199,13 @@ constexpr size_t get_dynamic_smem_size(){
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using elem_type = typename Kernel_traits::elem_type;
#else
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
assert(is_fp16_type);
using elem_type = __half;
#endif
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
......@@ -231,7 +240,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
......@@ -363,6 +372,10 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Do this part of P = Q * K^T.
gemm_q_k(acc_p);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
// }
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
if (!Is_first) { gmem_o_tmp.load(out, 0); }
......@@ -466,7 +479,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.template pack<__half>(frag_p);
softmax.template pack<elem_type>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
gmem_s.move();
......@@ -482,7 +495,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
frag_p[ki][mi].template hrelu_<__half>();
frag_p[ki][mi].template hrelu_<elem_type>();
}
}
}
......@@ -494,7 +507,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]);
fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
......@@ -605,7 +618,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Output the values.
if (is_final_write) {
gmem_o.template store<__half>(out, 0);
gmem_o.template store<elem_type>(out, 0);
gmem_o.move();
} else {
gmem_o_tmp.store(out, 0);
......
......@@ -32,6 +32,7 @@
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -50,7 +51,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -60,6 +61,11 @@ static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_BF16 ) {
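// Same packing as the fp16 branch above: replicate the bf16-encoded scale into both
// 16-bit halves of the 32-bit alpha word.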
__nv_bfloat16 x = __float2bfloat16( norm );
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_FP32 ) {
alpha = reinterpret_cast<const uint32_t &>( norm );
} else if( dtype == DATA_TYPE_INT32 ) {
......@@ -78,6 +84,8 @@ static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
return n * 4;
case DATA_TYPE_FP16:
return n * 2;
case DATA_TYPE_BF16:
return n * 2;
case DATA_TYPE_INT32:
return n * 4;
case DATA_TYPE_INT8:
......
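With `DATA_TYPE_BF16` wired into `set_alpha` and `get_size_in_bytes`, the host side can pack scaling constants for either element type through the same call. A small usage sketch; the `scale_bmm1` name and the 1/sqrt(head_size) scale follow the FMHA convention, but treat the exact call site as an assumption:

```cpp
#include <cmath>
#include <cstdint>

// Sketch: packing the softmax scaling factor for the selected element type.
void pack_scale(uint32_t &scale_bmm1, int head_size, bool is_bf16) {
    float softmax_scale = 1.f / std::sqrt(float(head_size));
    set_alpha(scale_bmm1, softmax_scale, is_bf16 ? DATA_TYPE_BF16 : DATA_TYPE_FP16);
}
```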
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
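This macro is what lets the launchers above collapse their hand-written template dispatch: each runtime boolean becomes a `constexpr` constant inside a nested `BOOL_SWITCH` lambda, and the element type is chosen with `std::conditional`, mirroring `run_fmha_fp16_sm80` and `run_fmha_dgrad_fp16_sm80`. A condensed sketch of the pattern; the kernel launch itself is elided:

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>
#include "static_switch.h"

// Condensed from the launchers in this commit; the actual kernel launch is omitted.
void dispatch_sketch(bool is_bf16, bool is_dropout, bool is_causal) {
    BOOL_SWITCH(is_bf16, IsBf16Const, [&] {
        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
            BOOL_SWITCH(is_causal, IsCausalConst, [&] {
                // At this point IsDropoutConst / IsCausalConst are constexpr and elem_type is a
                // concrete type, so a kernel such as
                //   fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
                //       FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>,
                //       IsDropoutConst, IsCausalConst>
                // can be named without enumerating every boolean combination by hand.
            });
        });
    });
}
```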