"vscode:/vscode.git/clone" did not exist on "fce0a57dec5635022d4170ba4a15e32a432b9cc7"
Commit ed4959b2 authored by Tri Dao's avatar Tri Dao
Browse files

Change inline to __forceinline__, use __grid_constant__ param

parent 6f706eff
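For context on the two changes in this commit: `__forceinline__` tells nvcc to inline a `__device__` function unconditionally rather than leaving it to the compiler's heuristics, and `__grid_constant__` (CUDA 11.7+, compute capability 7.0+) marks a `const` kernel parameter so it is read in place from kernel parameter space instead of being copied into per-thread local memory when its address is taken. A minimal, self-contained sketch of both annotations, using made-up names (`Params`, `scale_kernel`) rather than anything from this repository:

```cuda
// Minimal sketch, not from the flash-attention sources; Params and
// scale_kernel are made-up names. __grid_constant__ needs CUDA 11.7+ and
// an sm_70 or newer target.
#include <cuda_runtime.h>
#include <cstdio>

struct Params {          // stand-in for a large argument struct
    const float *in;
    float *out;
    float scale;
    int n;
};

// __forceinline__ makes the inlining request explicit instead of relying on
// the compiler's heuristics; useful for tiny helpers called in hot loops.
__forceinline__ __device__ float scale_one(float v, float s) { return v * s; }

// A __grid_constant__ const parameter is read directly from kernel parameter
// space; taking its address no longer forces a copy into thread-local memory.
__global__ void scale_kernel(__grid_constant__ const Params params) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < params.n) { params.out[i] = scale_one(params.in[i], params.scale); }
}

int main() {
    const int n = 8;
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) { h_in[i] = float(i); }
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    Params p{d_in, d_out, 2.0f, n};
    scale_kernel<<<1, 32>>>(p);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) { printf("%g ", h_out[i]); }  // 0 2 4 ... 14
    printf("\n");
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}
```

For a struct the size of `Flash_fwd_params`/`Flash_bwd_params`, avoiding that per-thread copy is presumably why the attention kernels below now take their parameters as `__grid_constant__ const`.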
@@ -19,7 +19,7 @@ struct Alibi {
     const float alibi_slope;
     const int max_seqlen_k, max_seqlen_q;

-    inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
         : alibi_slope(alibi_slope)
         , max_seqlen_k(max_seqlen_k)
         , max_seqlen_q(max_seqlen_q) {
@@ -27,7 +27,7 @@ struct Alibi {
     template <typename Engine, typename Layout>
-    inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                        const int col_idx_offset_,
                                        const int row_idx_offset,
                                        const int warp_row_stride) {
...
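The `apply_alibi` hunk only swaps the inlining qualifier; as a reminder of what the helper computes, ALiBi adds a per-head bias that decays linearly with the query/key distance, with positions right-aligned so the last query faces the last key. A rough scalar sketch of that bias on a plain row-major score buffer (hypothetical kernel, not the repository's cute::Tensor-based implementation):

```cuda
// Hedged sketch of the ALiBi bias on a plain [seqlen_q x seqlen_k] row-major
// score buffer; the names and layout are made up for illustration.
#include <cuda_runtime.h>

__global__ void add_alibi_bias(float *scores, float alibi_slope,
                               int seqlen_q, int seqlen_k) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;   // query position
    int col = blockIdx.x * blockDim.x + threadIdx.x;   // key position
    if (row < seqlen_q && col < seqlen_k) {
        // Right-align positions so the last query lines up with the last key,
        // then penalize by the absolute distance times the per-head slope.
        int rel = col - (row + seqlen_k - seqlen_q);
        scores[row * seqlen_k + col] -= alibi_slope * fabsf(float(rel));
    }
}
```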
@@ -24,12 +24,12 @@ struct BlockInfo {
     }

     template <typename index_t>
-    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
     }

     template <typename index_t>
-    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
     }
...
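The `q_offset`/`k_offset` hunks above are again qualifier-only changes; the ternary they contain selects between two addressing modes: `sum_s_q == -1` means fixed-length, padded batches addressed as `bidb * batch_stride`, otherwise sequences are packed back to back and the offset is the cumulative token count times the row stride. A small host-side illustration with made-up strides:

```cuda
// Toy host-side illustration (made-up strides) of the two addressing modes in
// q_offset/k_offset above: sum_s == -1 selects the padded, fixed-length layout
// addressed by batch index; otherwise sequences are packed contiguously and
// sum_s is the number of tokens preceding this sequence.
#include <cstdint>
#include <cstdio>

int64_t seq_offset(int64_t batch_stride, int64_t row_stride, int bidb, int sum_s) {
    return sum_s == -1 ? int64_t(bidb) * batch_stride
                       : int64_t(uint32_t(sum_s)) * row_stride;
}

int main() {
    // Padded layout: batch 2, sequences padded to 128 rows of 64 elements each.
    printf("%lld\n", (long long)seq_offset(128 * 64, 64, /*bidb=*/2, /*sum_s=*/-1));   // 16384
    // Packed (varlen) layout: 200 tokens precede this sequence in the buffer.
    printf("%lld\n", (long long)seq_offset(128 * 64, 64, /*bidb=*/2, /*sum_s=*/200));  // 12800
    return 0;
}
```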
@@ -14,7 +14,7 @@ struct Dropout {
     const unsigned long long seed, offset;
     const uint8_t p_dropout_in_uint8_t;

-    inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                               const uint8_t p_dropout_in_uint8_t,
                               const int bid, const int hid, const int tid, const int nheads)
         : seed(seed)
@@ -23,7 +23,7 @@ struct Dropout {
     }

     template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-    inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
+    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                          int block_row_start, int block_col_start, int block_row_stride) {
         // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
...
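`apply_dropout` is also only requalified here. Roughly, the idea behind `p_dropout_in_uint8_t` and the `encode_dropout_in_sign_bit` flag is to compare one random byte per element against a byte-sized keep threshold, and either zero dropped elements or negate them so a later conversion step can zero them with a fused relu (plausibly what the `relu2`/`convert_relu2` helpers further down this diff exist for). A hedged per-element sketch, not the repository's implementation:

```cuda
// Hedged per-element sketch of threshold dropout; keep_threshold and the
// sign-bit convention are assumptions for illustration, not the repository's
// exact code.
#include <cstdint>

__device__ __forceinline__ float drop_one(float x, uint8_t rnd, uint8_t keep_threshold,
                                          bool encode_in_sign_bit) {
    bool keep = rnd <= keep_threshold;   // keep_threshold roughly keep_prob * 255
    if (encode_in_sign_bit) {
        // Negate dropped elements; a later fused convert+relu can zero them,
        // and the sign still records which elements were dropped.
        return keep ? x : -x;
    }
    return keep ? x : 0.0f;
}
```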
@@ -448,7 +448,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     for (; m_block >= m_block_min; --m_block) {
...
@@ -12,33 +12,33 @@
 #include "flash_bwd_kernel.h"

 template<bool Clear_dQaccum=true, typename Kernel_traits>
-__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
     flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
     flash::clear_dKVaccum<Kernel_traits>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
     static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
     flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ<Kernel_traits>(params, nsplits);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     flash::convert_dKV<Kernel_traits>(params);
 }
...
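The main loop kernels above now take `__grid_constant__ const` params, while the smaller helper kernels simply take `const` params by value. Orthogonal to this commit but useful when reading the surrounding launch code: kernels that need more than the default 48 KB of dynamic shared memory must opt in with `cudaFuncSetAttribute` before launch. A self-contained sketch with a hypothetical kernel (not the repository's launch code):

```cuda
// Self-contained sketch with a hypothetical kernel; not the repository's
// launch code. Requesting more than 48 KB of dynamic shared memory requires
// an explicit opt-in; __grid_constant__ additionally needs sm_70+/CUDA 11.7+.
#include <cuda_runtime.h>
#include <cstdio>

struct DummyParams { int n; };   // stand-in for Flash_bwd_params

__global__ void dummy_bwd_kernel(__grid_constant__ const DummyParams params) {
    extern __shared__ float smem[];
    if (threadIdx.x < params.n) { smem[threadIdx.x] = float(threadIdx.x); }
    __syncthreads();
}

int main() {
    const int smem_size = 64 * 1024;   // above the 48 KB default limit
    cudaFuncSetAttribute(dummy_bwd_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    DummyParams params{128};
    dummy_bwd_kernel<<<1, 128, smem_size>>>(params);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    return 0;
}
```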
@@ -11,18 +11,18 @@
 #include "flash_fwd_kernel.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
-__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

 template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
-__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
     flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
-__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
     flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
...
@@ -11,7 +11,7 @@ namespace flash {
 using namespace cute;

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                   const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
@@ -35,7 +35,7 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 }

 template <bool HasWSLeft=true, typename Engine, typename Layout>
-inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
@@ -72,7 +72,7 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                          const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
@@ -81,7 +81,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void apply_mask_causal_w_idx(
+__forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
...
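The comment kept in `apply_mask_causal` above is the key fact here: causal masking is local masking with an infinite left window and a right window of 0. A scalar sketch of the predicate, ignoring the MMA tile indexing the real code performs (hypothetical helper name):

```cuda
// Scalar sketch of the masking predicate (hypothetical helper). Positions are
// right-aligned via the row + max_seqlen_k - max_seqlen_q shift, matching the
// hunks above. Causal == window_size_left = "infinity", window_size_right = 0.
__device__ __forceinline__ bool is_masked(int row, int col,
                                          int max_seqlen_q, int max_seqlen_k,
                                          int window_size_left, int window_size_right) {
    int shift = row + max_seqlen_k - max_seqlen_q;
    int col_limit_right = shift + 1 + window_size_right;   // first masked column on the right
    int col_limit_left  = shift - window_size_left;        // columns below this are masked
    return col >= col_limit_right || col < col_limit_left || col >= max_seqlen_k;
}
```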
@@ -9,7 +9,7 @@ struct ull2 {
     unsigned long long y;
 };

-inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     uint2 *res;
     unsigned long long tmp;
     asm ("mul.wide.u32 %0, %1, %2;\n\t"
@@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     return *res;
 }

-inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     constexpr unsigned long kPhiloxSA = 0xD2511F53;
     constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
     uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     return ret;
 }

-inline __device__ uint4 philox(unsigned long long seed,
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                unsigned long long subsequence,
                                unsigned long long offset) {
     constexpr unsigned long kPhilox10A = 0x9E3779B9;
...
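`mulhilo32`, `philox_single_round`, and `philox` above implement the Philox 4x32 counter-based generator; the same (seed, subsequence, offset) triple always produces the same four 32-bit words, which is what lets the backward pass re-derive the forward dropout pattern instead of storing it. A portable illustration of the widening multiply that the inline PTX `mul.wide.u32` performs (stand-alone sketch, not the repository's code):

```cuda
// Portable illustration of mulhilo32: the PTX mul.wide.u32 computes the full
// 64-bit product of two 32-bit values; Philox then uses the low and high
// halves separately.
#include <cstdint>
#include <cstdio>

struct U32Pair { uint32_t lo, hi; };

U32Pair mulhilo32_portable(uint32_t a, uint32_t b) {
    uint64_t prod = uint64_t(a) * uint64_t(b);
    return { uint32_t(prod), uint32_t(prod >> 32) };
}

int main() {
    U32Pair r = mulhilo32_portable(0xD2511F53u, 12345u);   // kPhiloxSA times a counter word
    printf("lo=0x%08x hi=0x%08x\n", r.lo, r.hi);
    return 0;
}
```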
@@ -20,7 +20,7 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
 }

 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
+__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
     CUTE_STATIC_ASSERT_V(size(dst) == size(src));
     #pragma unroll
     for (int i = 0; i < size(dst); i++){
@@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
 }

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
     thread_reduce_<zero_init>(tensor, summary, op);
     quad_allreduce_(summary, summary, op);
 }

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
+__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
     MaxOp<float> max_op;
     reduce_<zero_init>(tensor, max, max_op);
 }

 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
+__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
     reduce_(tensor, sum, sum_op);
 }

 // Apply the exp to all the elements.
 template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
+__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -85,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
 // Apply the exp to all the elements.
 template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
+__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -123,10 +123,10 @@ struct Softmax {
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;

-    inline __device__ Softmax() {};
+    __forceinline__ __device__ Softmax() {};

     template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
-    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -160,7 +160,7 @@ struct Softmax {
     };

     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
...
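`softmax_rescale_o` above is the online-softmax update at the core of FlashAttention: whenever a new block of scores raises a row's running maximum, the previously accumulated output and running sum are rescaled so the final normalization in `normalize_softmax_lse` stays exact. A scalar sketch of that update for a single row and output element (hypothetical function; the real code works on register tiles and uses `exp2f` with a log2-scaled factor):

```cuda
// Scalar sketch (hypothetical function) of the rescaling step for one row and
// one output element. State: m = running max, l = running sum of exponentials,
// acc = running unnormalized output. The new block contributes its own max,
// its sum of exp(s - block_max), and its exp-weighted value accumulation.
__device__ __forceinline__ void online_softmax_step(float &m, float &l, float &acc,
                                                    float block_max, float block_sum_exp,
                                                    float block_acc) {
    float m_new = fmaxf(m, block_max);
    float scale_prev = __expf(m - m_new);          // rescale what was accumulated so far
    float scale_cur  = __expf(block_max - m_new);  // rescale the new block's contribution
    l   = l * scale_prev + block_sum_exp * scale_cur;
    acc = acc * scale_prev + block_acc * scale_cur;
    m   = m_new;
}
```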
@@ -29,10 +29,10 @@ namespace flash {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T>
-inline __device__ uint32_t relu2(const uint32_t x);
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);

 template<>
-inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -50,7 +50,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<>
-inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
     asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@@ -63,10 +63,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<typename T>
-inline __device__ uint32_t convert_relu2(const float2 x);
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

 template<>
-inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
     const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -75,7 +75,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
 }

 template<>
-inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
     const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -89,20 +89,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
 template<typename T>
 struct MaxOp {
-__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
 };

 template <>
 struct MaxOp<float> {
 // This is slightly faster
-__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T>
 struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -111,7 +111,7 @@ template<int THREADS>
 struct Allreduce {
     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
     template<typename T, typename Operator>
-    static __device__ inline T run(T x, Operator &op) {
+    static __device__ __forceinline__ T run(T x, Operator &op) {
         constexpr int OFFSET = THREADS / 2;
         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
         return Allreduce<OFFSET>::run(x, op);
@@ -123,7 +123,7 @@ struct Allreduce {
 template<>
 struct Allreduce<2> {
     template<typename T, typename Operator>
-    static __device__ inline T run(T x, Operator &op) {
+    static __device__ __forceinline__ T run(T x, Operator &op) {
         x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
         return x;
     }
@@ -135,7 +135,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB>
-inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
                             TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                             ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
-inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
@@ -184,7 +184,7 @@ inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tenso
 // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
 template<typename Layout>
-inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
@@ -196,7 +196,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
 template<typename MMA_traits, typename Layout>
-inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -213,7 +213,7 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 template<typename Layout>
-inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -226,7 +226,7 @@ inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
     constexpr int numel = decltype(size(tensor))::value;
     cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
@@ -238,7 +238,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename Engine, typename Layout>
-inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
+__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
     constexpr int numel = decltype(size(tensor))::value;
     static_assert(numel % 2 == 0);
     using value_t = typename Engine::value_type;
@@ -254,7 +254,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
 // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
 template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
     static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
     static_assert(std::is_same_v<float, From_type>);
@@ -296,7 +296,7 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                             Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -365,7 +365,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 template <bool Is_even_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                       Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                       Tensor<Engine3, Layout3> const &predicate_K,
                                       const int max_MN=0, const int min_MN=0) {
@@ -395,7 +395,7 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
 template <bool Is_even_K=true, bool Clear_OOB_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                                Tensor<Engine1, Layout1> &D,
                                                Tensor<Engine2, Layout2> const &Cos,
                                                Tensor<Engine2, Layout2> const &Sin,
@@ -458,7 +458,7 @@ inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S
 template <bool Is_even_K=true, bool Clear_OOB_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                               Tensor<Engine1, Layout1> &D,
                                               Tensor<Engine2, Layout2> const &Cos,
                                               Tensor<Engine2, Layout2> const &Sin,
...
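Several of the hunks above touch the warp-level reduction helpers (`MaxOp`, `SumOp`, `Allreduce`), which use `__shfl_xor_sync` in a butterfly pattern so that every lane ends up with the reduced value. A self-contained sketch of the same pattern for a 32-lane sum (hypothetical kernel, not the repository's templated version):

```cuda
// Self-contained butterfly all-reduce over one warp: after log2(32) = 5
// xor-shuffle steps every lane holds the sum of all 32 inputs.
#include <cuda_runtime.h>
#include <cstdio>

__device__ __forceinline__ float warp_allreduce_sum(float x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;
}

__global__ void warp_sum_kernel(float *out) {
    float s = warp_allreduce_sum(float(threadIdx.x));   // lane i contributes i
    if (threadIdx.x == 0) { *out = s; }                 // 0 + 1 + ... + 31 = 496
}

int main() {
    float *d_out, h_out;
    cudaMalloc(&d_out, sizeof(float));
    warp_sum_kernel<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %g\n", h_out);   // expected 496
    cudaFree(d_out);
    return 0;
}
```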