Commit de19de7a authored by Tri Dao

Implement for bf16

parent 6a77a6da
@@ -31,12 +31,12 @@ Our tentative roadmap:
 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
-5. [Jun 2022] Support bf16.
+5. ~~[Jun 2022] Support bf16~~[Done].
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
 9. [Aug 2022] Fuse rotary embedding.
-10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
+10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

 ## Speedup and Memory Savings
...
@@ -56,11 +56,13 @@ void set_params_fprop(FMHA_fprop_params &params,
                       bool is_causal) {
     Data_type acc_type = DATA_TYPE_FP32;
-    Data_type data_type = DATA_TYPE_FP16;
+    Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;

     // Reset the parameters
     memset(&params, 0, sizeof(params));

+    params.is_bf16 = q.dtype() == torch::kBFloat16;
+
     // Set the pointers and strides.
     params.q_ptr = q.data_ptr();
     params.k_ptr = k.data_ptr();
@@ -192,9 +194,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
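The dtype guard above references `is_sm8x`, which is defined earlier in `mha_fwd`/`mha_bwd` and is not part of this hunk. A minimal sketch of the intended check, assuming the flag is derived from the current device properties (the definition below is illustrative, not taken from this diff):

```cpp
// Sketch only: is_sm8x is assumed to come from the device properties queried
// earlier in mha_fwd/mha_bwd. The intent is that bf16 inputs are accepted only
// on SM8x (Ampere-class) GPUs, while fp16 is accepted everywhere.
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
```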
@@ -326,14 +329,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
-    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    TORCH_CHECK(dout.dtype() == torch::kFloat16);
-    TORCH_CHECK(dq.dtype() == torch::kFloat16);
-    TORCH_CHECK(dk.dtype() == torch::kFloat16);
-    TORCH_CHECK(dv.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);
+    TORCH_CHECK(out.dtype() == q_dtype);
+    TORCH_CHECK(dout.dtype() == q_dtype);
+    TORCH_CHECK(dq.dtype() == q_dtype);
+    TORCH_CHECK(dk.dtype() == q_dtype);
+    TORCH_CHECK(dv.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
@@ -720,4 +724,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("bwd", &mha_bwd, "Backward pass");
     m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)");
     m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
-}
\ No newline at end of file
+}
@@ -123,6 +123,7 @@ struct FMHA_fprop_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;

+    bool is_bf16;
     bool is_causal;
 };
...
@@ -25,11 +25,13 @@
  *
  ******************************************************************************/

+#include <cuda_fp16.h>
+
 #pragma once

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
+template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
 struct FMHA_kernel_traits {

     // The CTA description for the 1st GEMM.

@@ -80,6 +82,8 @@ struct FMHA_kernel_traits {
     // The shared memory tile to store dp sum.
     using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;

+    using elem_type = elem_type_;
+
     // Make sure the number of threads match.
     static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
...
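The new trailing template parameter defaults to `__half`, so existing fp16 instantiations keep compiling unchanged, while bf16 builds pass `__nv_bfloat16` explicitly. A hedged sketch of the two instantiations (the S/D/STEP/WARPS values below are just one of the configurations used by the launchers, chosen for illustration):

```cpp
// Illustrative only: any of the tile configurations used in the launchers
// below would work the same way; only the last template argument differs.
using TraitsFp16 = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;                 // elem_type = __half (default)
using TraitsBf16 = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, __nv_bfloat16>;  // bf16 variant
```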
 /* Copyright (c) 2022, Tri Dao.
  */

+#include "static_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

@@ -22,106 +23,107 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
     constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;

-    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
-    bool is_causal = params.is_causal;
-    auto kernel = is_dropout
-        ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
-        : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
-    if (params.seqlen_k == blocksize_c) {
-        kernel = is_dropout
-            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
-            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-    } else if (params.seqlen_k == blocksize_c * 2) {
-        kernel = is_dropout
-            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
-            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
-    }
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
-    if( smem_size_dq_dk_dv >= 48 * 1024 ) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-    }
-    dim3 grid(params.b, params.h);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
-    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+
+    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                Kernel_traits, IsDropoutConst, IsCausalConst>;
+            if (params.seqlen_k == blocksize_c) {
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
+            } else if (params.seqlen_k == blocksize_c * 2) {
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
+            }
+            if( smem_size_dq_dk_dv >= 48 * 1024 ) {
+                FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+            }
+            dim3 grid(params.b, params.h);
+            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+            FMHA_CHECK_CUDA(cudaPeekAtLastError());
+        });
+    });
 }
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    if (params.d == 16) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else {
-            // TD [2022-05-15] 512 gives wrong results rn
-            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        }
-    } else if (params.d == 32) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k >= 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        }
-    } else if (params.d == 64) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k >= 256 ) {
-            auto dprops = at::cuda::getCurrentDeviceProperties();
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // Don't share smem for K & V, and don't keep V in registers
-                // This speeds things up by 2-3% by avoiding register spills, but it
-                // uses more shared memory, which is fine on A100 but not other GPUs.
-                // For other GPUs, we keep V in registers.
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-            } else if (dprops->major == 8 && dprops->minor > 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-            }
-        }
-    } else if (params.d == 128) {
-        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-        run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    }
-    // if (params.d == 64) {
-    //     auto dprops = at::cuda::getCurrentDeviceProperties();
-    //     if (dprops->major == 7 && dprops->minor == 5) {
-    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //     } else {
-    //         if( params.seqlen_k == 128 ) {
-    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-    //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //         } else if( params.seqlen_k >= 256 ) {
-    //             if (dprops->major == 8 && dprops->minor == 0) {
-    //                 // Don't share smem for K & V, and don't keep V in registers
-    //                 // This speeds things up by 2-3% by avoiding register spills, but it
-    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
-    //                 // For other GPUs, we keep V in registers.
-    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //             } else if (dprops->major == 8 && dprops->minor > 0) {
-    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //             }
-    //         }
-    //     }
-    // }
-    // if (params.d == 128) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-    //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    // }
+    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
+        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        if (params.d == 16) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k == 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else {
+                // TD [2022-05-15] 512 gives wrong results rn
+                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            }
+        } else if (params.d == 32) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k >= 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            }
+        } else if (params.d == 64) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k >= 256 ) {
+                if (dprops->major == 8 && dprops->minor == 0) {
+                    // Don't share smem for K & V, and don't keep V in registers
+                    // This speeds things up by 2-3% by avoiding register spills, but it
+                    // uses more shared memory, which is fine on A100 but not other GPUs.
+                    // For other GPUs, we keep V in registers.
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                } else if (dprops->major == 8 && dprops->minor > 0) {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                } else if (dprops->major == 7 && dprops->minor == 5) {
+                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                }
+            }
+        } else if (params.d == 128) {
+            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        }
+        // if (params.d == 64) {
+        //     if (dprops->major == 7 && dprops->minor == 5) {
+        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+        //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //     } else {
+        //         if( params.seqlen_k == 128 ) {
+        //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+        //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //         } else if( params.seqlen_k >= 256 ) {
+        //             if (dprops->major == 8 && dprops->minor == 0) {
+        //                 // Don't share smem for K & V, and don't keep V in registers
+        //                 // This speeds things up by 2-3% by avoiding register spills, but it
+        //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+        //                 // For other GPUs, we keep V in registers.
+        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //             } else if (dprops->major == 8 && dprops->minor > 0) {
+        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //             }
+        //         }
+        //     }
+        // }
+        // if (params.d == 128) {
+        //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+        //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        // }
+    });
 }
\ No newline at end of file
@@ -35,6 +35,14 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first,
 inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
                                                      const int loop_step_idx) {

+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using elem_type = typename Kernel_traits::elem_type;
+#else
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
+    assert(is_fp16_type);
+    using elem_type = __half;
+#endif
+
     // The description of the CTA tile for the 1st batched GEMM.
     using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
     // The description of the CTA tile for the 2nd batched GEMM.

@@ -106,7 +114,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

     // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
-    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
+    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false, elem_type>;

     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

@@ -214,7 +222,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
             gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
         );
     }

@@ -333,7 +341,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.template pack<__half>(frag_p);
+    softmax.template pack<elem_type>(frag_p);

     // Store s * dmask to smem for transpose
     smem_s.store(frag_p);

@@ -369,9 +377,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_do.load(frag_do[ki & 1], ki);
         if (!Kernel_traits::V_IN_REGS) {
             smem_v.load(frag_v[ki & 1], ki);
-            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+            fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
         }
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
         //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));

@@ -385,9 +393,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     {
         int ki = Mma_tile_p::MMAS_K;
         if (!Kernel_traits::V_IN_REGS) {
-            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+            fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
         }
     }

@@ -424,7 +432,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }

-    softmax.template pack<__half>(frag_p);
+    softmax.template pack<elem_type>(frag_p);

     // Store dp to smem for transpose
     smem_dp.store(frag_p);

@@ -442,14 +450,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Trigger the load from shared memory for the next series of Q values.
         smem_kt.load(frag_kt[ki & 1], ki);
         // Do the math for the values already in registers.
-        fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-        // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+        fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+        // fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
     }
     // Do the final stage of math.
     {
         int ki = Mma_tile_dq::MMAS_K;
-        fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-        // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+        fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+        // fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
     }

     static_assert(Gmem_tile_dq::LOOPS == 1);

@@ -475,7 +483,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
             #pragma unroll
             for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
-                frag_s[ki][mi].template hrelu_<__half>();
+                frag_s[ki][mi].template hrelu_<elem_type>();
             }
         }
     }

@@ -485,13 +493,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Trigger the load from shared memory for the next series of Q values.
         smem_dot.load(frag_dot[ki & 1], ki);
         // Do the math for the values already in registers.
-        fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+        fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
     }
     // Do the final stage of math.
     {
         int ki = Mma_tile_dkv::MMAS_K;
-        fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+        fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
     }

     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {

@@ -519,7 +527,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     if(l < steps - 1) {
         gmem_do.commit(smem_do);
         if (Is_first) {
-            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
+            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
                 gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
             );
         }

@@ -542,13 +550,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Trigger the load from shared memory for the next series of Q values.
         smem_qt.load(frag_qt[ki & 1], ki);
         // Do the math for the values already in registers.
-        fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+        fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
     }
     // Do the final stage of math.
     {
         int ki = Mma_tile_dkv::MMAS_K;
-        fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+        fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
     }

     // Make sure dQ is in shared memory.

@@ -575,7 +583,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
     }
     // Output the values.
-    gmem_dq.template store<__half>(dq_out, 0);
+    gmem_dq.template store<elem_type>(dq_out, 0);
     // Move to the next part of the output.
     gmem_dq.move();
 } else {

@@ -629,11 +637,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // the total amount of shared mem?
     // Epilogue swizzle for dV
     Smem_tile_dv smem_dv(&smem_[0], tidx);
-    smem_dv.template store<__half>(acc_dv);
+    smem_dv.template store<elem_type>(acc_dv);

     // Epilogue swizzle for dK
     Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
-    smem_dk.template store<__half>(acc_dk);
+    smem_dk.template store<elem_type>(acc_dk);

     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
...
@@ -25,6 +25,10 @@
  *
  ******************************************************************************/

+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "static_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

@@ -35,27 +39,9 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {

 template<typename Kernel_traits>
 void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                               const bool configure) {
-    bool is_causal = launch_params.params.is_causal;
-    // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
-    auto kernel = launch_params.is_dropout
-        ? (is_causal
-           ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
-           : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
-        : (is_causal
-           ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
-           : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));

     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
-    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
-    // Don't need smem_size_softmax_lse if we're not looping
-    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
-        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
-    if( smem_size >= 48 * 1024 ) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    }

     if (configure) {
         using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;

@@ -68,117 +54,133 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
         return;
     }

-    dim3 grid(launch_params.params.b, launch_params.params.h);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
-        launch_params.params);
-
-    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
+    // Don't need smem_size_softmax_lse if we're not looping
+    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
+
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
+            BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
+                auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
+                if( smem_size >= 48 * 1024 ) {
+                    FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                }
+                dim3 grid(launch_params.params.b, launch_params.params.h);
+                kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+                    launch_params.params);
+                FMHA_CHECK_CUDA(cudaPeekAtLastError());
+            });
+        });
+    });
 }
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
-    if (launch_params.params.d == 16) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.seqlen_k == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            // TD [2022-05-15] 512 gives wrong results rn
-            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>;
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        }
-    } else if (launch_params.params.d == 32) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.seqlen_k == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        }
-    } else if (launch_params.params.d == 64) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.seqlen_k >= 256 ) {
-            auto dprops = at::cuda::getCurrentDeviceProperties();
-            if (dprops->major == 8 && dprops->minor >= 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                if (launch_params.is_dropout) { // Need to use the same block size as backward
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                } else {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                }
-            }
-        }
-    } else if (launch_params.params.d == 128) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            auto dprops = at::cuda::getCurrentDeviceProperties();
-            if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
-                // TD [2022-06-05] Keep K in registers to reduce register spilling
-                // Gives about 6% speedup compared to using block size 128.
-                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else { // Need to use the same block size as backward
-                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            }
-        }
-    }
-    // if (launch_params.params.d == 64) {
-    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
-    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    // }
-    // if (launch_params.params.d == 64) {
-    //     if( launch_params.params.seqlen_k == 128 ) {
-    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //     } else if( launch_params.params.seqlen_k >= 256 ) {
-    //         auto dprops = at::cuda::getCurrentDeviceProperties();
-    //         if (dprops->major == 8 && dprops->minor >= 0) {
-    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //         } else if (dprops->major == 7 && dprops->minor == 5) {
-    //             // if (launch_params.is_dropout) { // Need to use the same block size as backward
-    //             //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //             //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //             // } else {
-    //             //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-    //             //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //             // }
-    //         }
-    //     }
-    // }
-    // if (launch_params.params.d == 128) {
-    //     if( launch_params.params.seqlen_k == 128 ) {
-    //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //     } else {
-    //         auto dprops = at::cuda::getCurrentDeviceProperties();
-    //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
-    //             // TD [2022-06-05] Keep K in registers to reduce register spilling
-    //             // Gives about 6% speedup compared to using block size 128.
-    //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
-    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //         } else { // Need to use the same block size as backward
-    //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //         }
-    //     }
-    // }
+    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
+        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        if (launch_params.params.d == 16) {
+            if( launch_params.params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if( launch_params.params.seqlen_k == 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else {
+                // TD [2022-05-15] 512 gives wrong results rn
+                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
+        } else if (launch_params.params.d == 32) {
+            if( launch_params.params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if( launch_params.params.seqlen_k == 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else {
+                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
+        } else if (launch_params.params.d == 64) {
+            if( launch_params.params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if( launch_params.params.seqlen_k >= 256 ) {
+                if (dprops->major == 8 && dprops->minor >= 0) {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                } else if (dprops->major == 7 && dprops->minor == 5) {
+                    if (launch_params.is_dropout) { // Need to use the same block size as backward
+                        using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+                        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                    } else {
+                        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+                        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                    }
+                }
+            }
+        } else if (launch_params.params.d == 128) {
+            if( launch_params.params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else {
+                if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
+                    // TD [2022-06-05] Keep K in registers to reduce register spilling
+                    // Gives about 6% speedup compared to using block size 128.
+                    using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                } else { // Need to use the same block size as backward
+                    using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                }
+            }
+        }
+        // if (launch_params.params.d == 64) {
+        //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+        //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
+        //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
+        //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        // }
+        // if (launch_params.params.d == 64) {
+        //     if( launch_params.params.seqlen_k == 128 ) {
+        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     } else if( launch_params.params.seqlen_k >= 256 ) {
+        //         if (dprops->major == 8 && dprops->minor >= 0) {
+        //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         } else if (dprops->major == 7 && dprops->minor == 5) {
+        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
+        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             } else {
+        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             }
+        //         }
+        //     }
+        // }
+        // if (launch_params.params.d == 128) {
+        //     if( launch_params.params.seqlen_k == 128 ) {
+        //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     } else {
+        //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
+        //             // TD [2022-06-05] Keep K in registers to reduce register spilling
+        //             // Gives about 6% speedup compared to using block size 128.
+        //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         } else { // Need to use the same block size as backward
+        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         }
+        //     }
+        // }
+    });
 }
\ No newline at end of file
@@ -72,7 +72,7 @@ struct Gemm_Q_K_base {
     Smem_tile_k smem_k;
 };

-template<typename Kernel_traits, bool K_in_regs>
+template<typename Kernel_traits, bool K_in_regs, typename elem_type_=__half>
 struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

     using Base = Gemm_Q_K_base<Kernel_traits>;

@@ -81,6 +81,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
     using Smem_tile_k = typename Base::Smem_tile_k;
     using Fragment_k = typename Base::Fragment_k;
     using Mma_tile_p = typename Base::Mma_tile_p;
+    using elem_type = elem_type_;

     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
     // If V is stored in shared memory, we can't load K using the same shared memory.

@@ -115,12 +116,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
             // Trigger the load from shared memory for the next series of Q values.
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
     }

@@ -132,8 +133,8 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
 };

-template<typename Kernel_traits>
-struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
+template<typename Kernel_traits, typename elem_type_>
+struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {
     using Base = Gemm_Q_K_base<Kernel_traits>;
     using Smem_tile_o = typename Base::Smem_tile_o;
     using Smem_tile_q = typename Base::Smem_tile_q;

@@ -141,6 +142,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
     using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
     using Fragment_k = typename Base::Fragment_k;
     using Mma_tile_p = typename Base::Mma_tile_p;
+    using elem_type = elem_type_;

     Fragment_k frag_k[2][Mma_tile_p::MMAS_N];

     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;

@@ -175,12 +177,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             Base::smem_k.load(frag_k[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
     }
@@ -197,6 +199,13 @@ constexpr size_t get_dynamic_smem_size(){

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
 inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {

+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using elem_type = typename Kernel_traits::elem_type;
+#else
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
+    assert(is_fp16_type);
+    using elem_type = __half;
+#endif
+
     // The description of the CTA tile for the 1st batched GEMM.
     using Cta_tile_p = typename Kernel_traits::Cta_tile_p;

@@ -231,7 +240,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;

-    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
+    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;

     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

@@ -363,6 +372,10 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Do this part of P = Q * K^T.
         gemm_q_k(acc_p);

+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+        //     printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
+        // }
+
         uint4 out[Gmem_tile_o::STGS_PER_LOOP];
         if (!Is_first) { gmem_o_tmp.load(out, 0); }

@@ -466,7 +479,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
         static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
         static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
-        softmax.template pack<__half>(frag_p);
+        softmax.template pack<elem_type>(frag_p);

         if (Return_softmax) {
             gmem_s.store(frag_p, mask);
             gmem_s.move();

@@ -482,7 +495,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
             for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
                 #pragma unroll
                 for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
-                    frag_p[ki][mi].template hrelu_<__half>();
+                    frag_p[ki][mi].template hrelu_<elem_type>();
                 }
             }
         }

@@ -494,7 +507,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Do this part of O = P^T * V^T.
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-            fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]);
+            fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
             // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
             //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
             //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));

@@ -605,7 +618,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Output the values.
         if (is_final_write) {
-            gmem_o.template store<__half>(out, 0);
+            gmem_o.template store<elem_type>(out, 0);
             gmem_o.move();
         } else {
             gmem_o_tmp.store(out, 0);
...
@@ -32,6 +32,7 @@
 #include <stdlib.h>
 #include <cuda_runtime_api.h>
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -50,7 +51,7 @@

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
+enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -60,6 +61,11 @@ static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
         uint16_t h = reinterpret_cast<const uint16_t &>( x );
         ushort2 h2 = { h, h };
         alpha = reinterpret_cast<const uint32_t &>( h2 );
+    } else if( dtype == DATA_TYPE_BF16 ) {
+        __nv_bfloat16 x = __float2bfloat16( norm );
+        uint16_t h = reinterpret_cast<const uint16_t &>( x );
+        ushort2 h2 = { h, h };
+        alpha = reinterpret_cast<const uint32_t &>( h2 );
     } else if( dtype == DATA_TYPE_FP32 ) {
         alpha = reinterpret_cast<const uint32_t &>( norm );
     } else if( dtype == DATA_TYPE_INT32 ) {

@@ -78,6 +84,8 @@ static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
         return n * 4;
     case DATA_TYPE_FP16:
         return n * 2;
+    case DATA_TYPE_BF16:
+        return n * 2;
     case DATA_TYPE_INT32:
         return n * 4;
     case DATA_TYPE_INT8:
...
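For reference, the bf16 branch of `set_alpha` above stores the converted scale in both 16-bit lanes of a 32-bit word. A small host-side sketch of the equivalent packing (not part of the commit; `memcpy` is used instead of `reinterpret_cast` purely for clarity):

```cpp
#include <cuda_bf16.h>
#include <cstdint>
#include <cstring>

// Returns a 32-bit value holding __float2bfloat16(norm) in both 16-bit lanes,
// mirroring what set_alpha does for DATA_TYPE_BF16.
static uint32_t pack_bf16x2(float norm) {
    __nv_bfloat16 x = __float2bfloat16(norm);
    uint16_t h;
    std::memcpy(&h, &x, sizeof(h));
    return (static_cast<uint32_t>(h) << 16) | h;
}
```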
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
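A hedged usage sketch of `BOOL_SWITCH`, mirroring how the launchers in this commit turn runtime flags into template parameters (the `launch` wrapper below is a stand-in for illustration, not a function from the repository):

```cpp
#include "static_switch.h"

template<bool IsDropout, bool IsCausal>
void launch();  // hypothetical kernel wrapper, for illustration only

void dispatch(bool is_dropout, bool is_causal) {
    // Each BOOL_SWITCH instantiates both branches at compile time and picks
    // one at runtime, so the bools become usable as template arguments.
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        BOOL_SWITCH(is_causal, IsCausalConst, [&] {
            launch<IsDropoutConst, IsCausalConst>();
        });
    });
}
```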