Unverified commit d562aa63, authored by Woosuk Kwon and committed by GitHub

Sync with FA v2.6.0 to support soft capping (#13)

parent 12375706
@@ -4,8 +4,8 @@ recursive-include csrc *.cuh
 recursive-include csrc *.cpp
 recursive-include csrc *.hpp
-recursive-include flash_attn *.cu
-recursive-include flash_attn *.h
-recursive-include flash_attn *.cuh
-recursive-include flash_attn *.cpp
-recursive-include flash_attn *.hpp
+recursive-include vllm_flash_attn *.cu
+recursive-include vllm_flash_attn *.h
+recursive-include vllm_flash_attn *.cuh
+recursive-include vllm_flash_attn *.cpp
+recursive-include vllm_flash_attn *.hpp
@@ -314,6 +314,11 @@ Implement deterministic backward pass. Thanks to engineers from [Meituan](www.me
 Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
 Thanks to @beginlner for this contribution.
+### 2.6: Softcapping.
+Support attention with softcapping, as used in Gemma-2 and Grok models.
+Thanks to @Narsil for this contribution.
 ## Performance
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
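For context on the softcapping entry above, as it can be read off the `set_params_fprop` changes later in this commit: with softcapping enabled, each scaled attention logit `s = softmax_scale * (q · k)` is replaced by `softcap * tanh(s / softcap)` before the softmax, which keeps every logit inside `(-softcap, softcap)`; passing `softcap <= 0` leaves the logits unchanged. This is the attention-logit soft capping used by Gemma-2 and Grok.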
@@ -400,12 +405,13 @@ If you use this codebase, or otherwise found our work valuable, please cite:
 @inproceedings{dao2022flashattention,
   title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  booktitle={Advances in Neural Information Processing Systems},
+  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
   year={2022}
 }
-@article{dao2023flashattention2,
+@inproceedings{dao2023flashattention2,
   title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
   author={Dao, Tri},
-  year={2023}
+  booktitle={International Conference on Learning Representations (ICLR)},
+  year={2024}
 }
 ```
@@ -43,7 +43,9 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool seqlenq_ngroups_swapped=false) {
+                      const float softcap,
+                      bool seqlenq_ngroups_swapped=false,
+                      const bool unpadded_lse=false) {
     // Reset the parameters
     params = {};
@@ -99,8 +101,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d_rounded = d_rounded;

     // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    #endif
+    if (softcap > 0.0) {
+        params.softcap = softmax_scale / softcap;
+        params.scale_softmax = softcap;
+        params.scale_softmax_log2 = softcap * M_LOG2E;
+    } else {
+        // Remove potential NaN
+        params.softcap = 0.0;
+        params.scale_softmax = softmax_scale;
+        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    }

     // Set this to probability of keeping an element to simplify things.
     params.p_dropout = 1.f - p_dropout;
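To make the scale bookkeeping above concrete, here is a minimal host-side sketch (illustrative only; `effective_logit` is not a function in this codebase). The kernels multiply the raw scores by `params.softcap` before a `tanh` and then apply `params.scale_softmax` as the usual softmax scale, which is why the branch above folds `softmax_scale / softcap` and `softcap` into those two slots.

```cpp
#include <cmath>

// Illustrative only: the effective attention logit for one raw dot product q.k,
// matching the parameter choices made in set_params_fprop above.
inline float effective_logit(float qk_dot, float softmax_scale, float softcap) {
    if (softcap > 0.f) {
        // tanh(qk_dot * softmax_scale / softcap) is rescaled by softcap, so the
        // result stays in (-softcap, softcap) while remaining ~linear for small scores.
        return softcap * std::tanh(qk_dot * (softmax_scale / softcap));
    }
    // No capping: the usual scaled dot product.
    return qk_dot * softmax_scale;
}
```

Dropout is not supported together with softcapping, which the `TORCH_CHECK`s added in `mha_fwd` and `mha_varlen_fwd` below enforce.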
@@ -135,16 +148,21 @@ void set_params_fprop(Flash_fwd_params &params,
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
     TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
     #endif
+
+    params.unpadded_lse = unpadded_lse;
+    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }

 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
-                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-            } else {
-                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-            }
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+                } else {
+                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
+                }
+            });
         });
     });
 }
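The new `BOOL_SWITCH` level in the dispatch above converts the runtime flag `params.is_causal` into the compile-time `Is_causal` template argument that `run_mha_fwd_` and `run_mha_fwd_splitkv_dispatch` now take. For reference, FlashAttention's `BOOL_SWITCH` (from `static_switch.h`) is roughly the following macro; the exact definition may differ slightly:

```cpp
// Roughly: evaluate the runtime condition once, and in each branch declare the
// flag as a constexpr so the lambda body can pass it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()
```

Each (dtype, head dim, Is_causal) combination is then instantiated in its own auto-generated file (the new causal kernel files further down), which is what the "Splitting the different head dimensions to different files to speed up compilation" comment in those files refers to.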
@@ -248,6 +266,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float softcap,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -286,6 +305,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
@@ -369,7 +390,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     softcap
+                     );

     set_params_splitkv(params, batch_size, num_heads,
@@ -437,6 +460,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_) {
@@ -485,6 +509,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int head_size_og = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
@@ -553,7 +579,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
     CHECK_DEVICE(out);
     TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-    CHECK_SHAPE(out, total_q, num_heads, head_size_og);
     CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
     if (seqlenq_ngroups_swapped) {
         out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
@@ -574,8 +599,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
@@ -606,7 +630,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     seqlenq_ngroups_swapped);
+                     softcap,
+                     seqlenq_ngroups_swapped,
+                     /*unpadded_lse*/true);
+    params.total_q = total_q;

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
@@ -662,7 +689,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
         out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }

     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
@@ -685,6 +712,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 bool is_causal,
                 int window_size_left,
                 int window_size_right,
+                const float softcap,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {
@@ -826,7 +854,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      /*p_dropout=*/0.f,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     softcap
+                     );

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
...
@@ -31,7 +31,7 @@ struct Alibi {
        const int col_idx_offset_,
        const int row_idx_offset,
        const int warp_row_stride) {
-        // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane_id = threadIdx.x % 32;
        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
...
@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

     // The scaling factors for the kernel.
     float scale_softmax;
@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
     // Local window size
     int window_size_left, window_size_right;
+    float softcap;

     // Random state.
     at::PhiloxCudaState philox_args;
@@ -138,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
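To illustrate the `unpadded_lse` comment above, here is an indexing sketch of our own (not code from this header): the padded layout stores the log-sum-exp as `[b, nheads, seqlen_q]` with every sequence padded to the batch maximum, while the unpadded varlen layout packs all `total_q` query positions back to back per head.

```cpp
#include <cstdint>

// Illustrative index math only; the helper names are ours. cu_seqlens_q_b is the
// packed offset of sequence b, as provided by the varlen API's cu_seqlens_q.
inline int64_t lse_index_padded(int64_t b, int64_t h, int64_t i,
                                int64_t nheads, int64_t seqlen_q) {
    // Layout [b, nheads, seqlen_q]: every sequence occupies seqlen_q slots.
    return (b * nheads + h) * seqlen_q + i;
}

inline int64_t lse_index_unpadded(int64_t h, int64_t cu_seqlens_q_b, int64_t i,
                                  int64_t total_q) {
    // Layout [nheads, total_q]: sequences are packed back to back, so no slots
    // are wasted on padding.
    return h * total_q + (cu_seqlens_q_b + i);
}
```

This is why `mha_varlen_fwd` above now allocates `softmax_lse` as `{num_heads, total_q}` and passes `/*unpadded_lse*/true` into `set_params_fprop`.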
@@ -184,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream);
}