// Commit 34e67b1e, authored by zhangshao.
//
// first commit
//
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
#include <cutlass/numeric_types.h>
// #include "namespace_config.h"
// #include "hardware_info.h"
#include "flash_sparse.h"
#include "static_switch.h"
// Argument-validation helpers. Each one raises through TORCH_CHECK with a
// message that contains the stringified argument name.
// CHECK_DEVICE: the tensor must report is_cuda().
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
// CHECK_SHAPE: x.sizes() must equal exactly the listed extents (rank and values).
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// CHECK_CONTIGUOUS: the tensor must report is_contiguous().
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Debug switch evaluated once during static initialization. When true, the
// entry points below printf their arguments.
// NOTE(review): get_env_ is declared elsewhere; presumably it reads the
// FLASH_ATTENTION_PRINT_PARAM environment variable — confirm.
static const bool print_param = get_env_("FLASH_ATTENTION_PRINT_PARAM");
//
// Bit hacky but for now hook into the existing set_params_fprop,
// set_params_splitkv, and set_params_alibi in flash_api.cpp
//
// Forward declaration only: the definition is not part of this file (the note
// above places it in flash_api.cpp). set_params_fprop_sparse below calls it to
// fill the fields shared with the dense forward pass.
// From this file the call passes is_bhsd=false explicitly and leaves d_v,
// d_v_rounded and is_vllm_kvcache at their defaults.
void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t seqlen_q_rounded,
const size_t seqlen_k_rounded,
const size_t h,
const size_t h_k,
const size_t d,
const size_t d_rounded,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *seqused_k,
void *p_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right,
const float softcap,
bool is_bhsd = false,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false,
int d_v=0,
int d_v_rounded=0,
bool is_vllm_kvcache=false
);
// Forward declaration only (defined elsewhere). mha_fwd_sparse calls it with
// num_splits = 1 and keeps the two returned tensors alive in locals named
// softmax_lse_accum and out_accum.
std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts);
// Forward declaration only (defined elsewhere). Takes the optional ALiBi slopes
// by non-const reference, which is why the callers in this file const_cast
// their const optional before passing it.
void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads);
///////////////////////////////////////////////////////////////////////////////
// Populates `params` for the sparse forward kernels.
//
// Everything shared with the dense forward pass is delegated to
// set_params_fprop (declared above, defined elsewhere). On top of that this
// function attaches the four sparse-index tensors and derives NUM_ROWS, NNZ_S
// and NNZ_V from their shapes:
//   block_count  / column_count : dim 2 is the number of query row-blocks
//   block_offset                : dim 2 as above, dim 3 becomes params.NNZ_S
//   column_index                : dim 2 as above, dim 3 becomes params.NNZ_V
//
// Only raw int pointers and those three extents are stored (no strides), so
// the index tensors must be contiguous CUDA tensors; const_data_ptr<int>()
// additionally rejects any dtype other than int32.
//
// Raises via TORCH_CHECK when an index tensor is not on CUDA, is not
// contiguous, when the row-block extents disagree, or when the row-block
// count does not equal ceil(seqlen_q / BLOCK_M).
void set_params_fprop_sparse(Flash_fwd_params_sparse &params,
                             // sizes
                             const size_t b,
                             const size_t seqlen_q,
                             const size_t seqlen_k,
                             const size_t seqlen_q_rounded,
                             const size_t seqlen_k_rounded,
                             const size_t h,
                             const size_t h_k,
                             const size_t d,
                             const size_t d_rounded,
                             // device pointers
                             const at::Tensor q,
                             const at::Tensor k,
                             const at::Tensor v,
                             const at::Tensor block_count,
                             const at::Tensor block_offset,
                             const at::Tensor column_count,
                             const at::Tensor column_index,
                             at::Tensor out,
                             void *cu_seqlens_q_d,
                             void *cu_seqlens_k_d,
                             void *seqused_k,
                             void *p_d,
                             void *softmax_lse_d,
                             float p_dropout,
                             float softmax_scale,
                             int64_t window_size_left,
                             int64_t window_size_right,
                             const float softcap,
                             bool seqlenq_ngroups_swapped=false,
                             const bool unpadded_lse=false) {
    // Validate the index tensors before touching `params`: their data pointers
    // are handed to the kernel as-is, so a host tensor or a strided view would
    // otherwise only surface as an illegal address or silently wrong results.
    CHECK_DEVICE(block_count); CHECK_DEVICE(block_offset);
    CHECK_DEVICE(column_count); CHECK_DEVICE(column_index);
    CHECK_CONTIGUOUS(block_count); CHECK_CONTIGUOUS(block_offset);
    CHECK_CONTIGUOUS(column_count); CHECK_CONTIGUOUS(column_index);

    // All four tensors must agree on the number of query row-blocks (dim 2).
    TORCH_CHECK(block_count.size(2) == block_offset.size(2),
                "block_count and block_offset must have the same size in dim 2");
    TORCH_CHECK(column_index.size(2) == block_offset.size(2),
                "column_index and block_offset must have the same size in dim 2");
    TORCH_CHECK(column_count.size(2) == column_index.size(2),
                "column_count and column_index must have the same size in dim 2");

    set_params_fprop(params,
                     b,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     h, h_k,
                     d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     seqused_k,
                     p_d,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     /*is_bhsd=*/false,
                     seqlenq_ngroups_swapped,
                     unpadded_lse);

    params.block_count = block_count.const_data_ptr<int>();
    params.block_offset = block_offset.const_data_ptr<int>();
    params.column_count = column_count.const_data_ptr<int>();
    params.column_index = column_index.const_data_ptr<int>();

    params.NUM_ROWS = block_count.size(2);
    // Row-block size this check assumes: 64 rows up to seqlen_q == 2048, 128 beyond.
    // NOTE(review): presumably mirrors the tile size picked by the kernel launcher — confirm.
    int BLOCK_M = (seqlen_q <= 2048) ? 64 : 128;
    int expected_num_rows = (seqlen_q + BLOCK_M - 1) / BLOCK_M;
    TORCH_CHECK(params.NUM_ROWS == expected_num_rows,
                "NUM_ROWS mismatch: got ", params.NUM_ROWS, " but expected ", expected_num_rows,
                " (seqlen_q=", seqlen_q, ", BLOCK_M=", BLOCK_M, ")");
    params.NNZ_S = block_offset.size(3);
    params.NNZ_V = column_index.size(3);
}
// Dispatches the sparse forward kernel for the element type (fp16 unless
// params.is_bf16) and causal flag stored in `params`.
// Split-KV is rejected, and only head dimension 128 is instantiated.
void run_mha_fwd_sparse(Flash_fwd_params_sparse &params, cudaStream_t stream, bool force_split_kernel=false) {
    TORCH_CHECK(params.num_splits <= 1 && !force_split_kernel, "run_mha_fwd_sparse does not support splitkv.");
    TORCH_CHECK(params.d == 128, "run_mha_fwd_sparse only supports headdim=128 for now to keep binary small.");
    // The head dimension is fixed, so it is named once for both switches.
    constexpr static int kHeadDim = 128;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        FP16_SWITCH(!params.is_bf16, [&] {
            run_mha_fwd_sparse_<elem_type, kHeadDim, Is_causal>(params, stream);
        });
    });
}
#if 1
// Dispatches the "SLA" variant of the sparse forward kernel.
//   - fp8 inputs (params.is_fp8): only head dimension 128, element type
//     cutlass::float_e4m3_t.
//   - otherwise: head dimension 64 or 128, fp16 unless params.is_bf16.
// Split-KV is rejected in both cases.
void run_mha_fwd_sparse_sla(Flash_fwd_params_sparse &params, cudaStream_t stream, bool force_split_kernel=false) {
    // The message now names this function; it previously reported run_mha_fwd_sparse.
    TORCH_CHECK(params.num_splits <= 1 && !force_split_kernel, "run_mha_fwd_sparse_sla does not support splitkv.");
    if (params.is_fp8) {
        TORCH_CHECK(params.d == 128, "run_mha_fwd_sparse_sla only supports headdim=128 for now to keep binary small.");
        constexpr static int kHeadDim = 128;
        using elem_type = cutlass::float_e4m3_t;
        run_mha_fwd_sparse_sla_fp8_<elem_type, kHeadDim>(params, stream);
        return;
    }
    TORCH_CHECK(params.d == 128 or params.d == 64, "run_mha_fwd_sparse_sla only supports headdim=(64, 128) for now to keep binary small.");
    // HEADDIM_SWITCH_SLA supplies kHeadDim, FP16_SWITCH supplies elem_type.
    HEADDIM_SWITCH_SLA(params.d, [&] {
        FP16_SWITCH(!params.is_bf16, [&] {
            run_mha_fwd_sparse_sla_<elem_type, kHeadDim>(params, stream);
        });
    });
}
#endif
// Sparse attention forward pass for fixed-length (batched) inputs.
//
// Inputs
//   q                   : batch_size x seqlen_q x num_heads   x head_size
//   k, v                : batch_size x seqlen_k x num_heads_k x head_size
//   block_count, block_offset, column_count, column_index
//                       : sparse index tensors, forwarded to set_params_fprop_sparse
//   out_                : optional preallocated output with the shape of q
//   alibi_slopes_       : optional, num_heads or batch_size x num_heads
//   is_sla              : selects run_mha_fwd_sparse_sla instead of run_mha_fwd_sparse
//   pv_threshold, enable_dynamic_skip
//                       : copied into params for the dynamic PV-skip optimisation
//   gen_                : unused (RNG setup is disabled, see note in the body)
//
// Returns {out, softmax_lse}; softmax_lse is float32 of shape
// (batch_size, num_heads, seqlen_q). When head_size is not a multiple of 8 the
// inputs are zero-padded, the kernel writes a padded buffer, and the result is
// sliced back (and copied into out_ when it was provided).
//
// Raises via TORCH_CHECK on unsupported dtype/device/shape combinations.
extern "C"
std::vector<at::Tensor>
mha_fwd_sparse(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
               const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
               const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
               const at::Tensor &block_count,
               const at::Tensor &block_offset,
               const at::Tensor &column_count,
               const at::Tensor &column_index,
               const std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
               const std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
               const double p_dropout,
               const double softmax_scale,
               bool is_causal,
               const double softcap,
               const bool return_softmax,
               std::optional<at::Generator> gen_,
               bool is_sla = false,
               const double pv_threshold = 50.0,  // Dynamic PV skip threshold
               const bool enable_dynamic_skip = true  // Enable dynamic PV skip optimization
               ) {
    // auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    // bool is_sm8x_min = cc_major >= 8;
    // TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn,
                "FlashAttention only support fp16, bf16 and fp8 (e4m3) data type");
    // run_mha_fwd_sparse only dispatches fp16/bf16; without this guard fp8 data
    // would be run through the fp16 kernel.
    TORCH_CHECK(is_sla || q_dtype != torch::kFloat8_e4m3fn,
                "fp8 (e4m3) inputs are only supported with is_sla=true");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    // sizes[i] is not bounds-checked, so the rank has to be verified first.
    TORCH_CHECK(q.dim() == 4, "q must have shape (batch_size, seqlen_q, num_heads, head_size)");
    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    // Guard the modulo below against a zero divisor.
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
    int64_t window_size_left = -1;
    int64_t window_size_right = -1;
    if (is_causal) { window_size_right = 0; }

    if (print_param) {
        // The window sizes are int64_t; cast them so they match the %d conversions.
        printf("mha_fwd_sparse fa input size bshd=(%d,%d,%d,%d),p_dropout=%.3f,softmax_scale=%.3f,is_causal=%d,window_size_left=%d,window_size_right=%d,softcap=%f,return_softmax=%d,is_bhsd=%d\n",
               (int)sizes[0], (int)sizes[1], (int)sizes[2], (int)sizes[3], p_dropout, softmax_scale, (int)is_causal,
               (int)window_size_left, (int)window_size_right, softcap, (int)return_softmax, 0);
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

    // Pad the head dimension up to a multiple of 8.
    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
        // With padding the kernel writes a scratch buffer; out_ is filled by copy at the end.
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }

    Flash_fwd_params_sparse params;
    set_params_fprop_sparse(params,
                            batch_size,
                            seqlen_q, seqlen_k,
                            seqlen_q_rounded, seqlen_k_rounded,
                            num_heads, num_heads_k,
                            head_size, head_size_rounded,
                            q_padded, k_padded, v_padded,
                            block_count, block_offset,
                            column_count, column_index,
                            out,
                            /*cu_seqlens_q_d=*/nullptr,
                            /*cu_seqlens_k_d=*/nullptr,
                            /*seqused_k=*/nullptr,
                            return_softmax ? p.data_ptr() : nullptr,
                            softmax_lse.data_ptr(),
                            p_dropout,
                            softmax_scale,
                            window_size_left,
                            window_size_right,
                            softcap);

    // Set dynamic PV skip parameters
    params.pv_threshold = static_cast<float>(pv_threshold);
    params.enable_dynamic_skip = enable_dynamic_skip;

    auto dprops = at::cuda::getCurrentDeviceProperties();
    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
        head_size_rounded, p_dropout, /*num_splits*/ 1, dprops, opts);

    // NOTE(woosuk): the Philox RNG state setup (rng_state / philox_args derived
    // from gen_) is intentionally left out because it is not used in inference;
    // gen_ is therefore unused here.

    // for alibi_slopes_ cast away constness that was added for torch library
    // compatibility, needs to be cast away to maintain compatibility with
    // upstream
    set_params_alibi(params,
                     const_cast<std::optional<at::Tensor> &>(alibi_slopes_),
                     batch_size, num_heads);

    if (seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        if (is_sla) {
            run_mha_fwd_sparse_sla(params, stream);
        } else {
            run_mha_fwd_sparse(params, stream);
        }
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    if (head_size_og % 8 != 0) {
        // Drop the padding columns again and mirror the result into the caller's buffer.
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, softmax_lse};
}
// Sparse attention forward pass for variable-length (packed) inputs.
//
// Inputs
//   q                   : total_q x num_heads   x head_size
//   k, v                : total_k x num_heads_k x head_size
//   block_count, block_offset, column_count, column_index
//                       : sparse index tensors, forwarded to set_params_fprop_sparse
//   out_                : optional preallocated output with the shape of q
//   cu_seqlens_q/k      : int32, contiguous, CUDA, batch_size + 1 entries
//   seqused_k           : optional int32, contiguous, CUDA, batch_size entries
//   alibi_slopes_       : optional, num_heads or batch_size x num_heads
//   zero_tensors        : pre-fills out with 0 and softmax_lse with -inf
//   pv_threshold, enable_dynamic_skip
//                       : copied into params for the dynamic PV-skip optimisation
//   gen_                : unused (RNG setup is disabled, see note in the body)
//
// Returns {out, softmax_lse}; softmax_lse is float32 of shape
// (num_heads, total_q). When head_size is not a multiple of 8 the inputs are
// zero-padded, the kernel writes a padded buffer, and the result is sliced
// back (and copied into out_ when it was provided).
//
// Raises via TORCH_CHECK on unsupported dtype/device/shape combinations.
extern "C"
std::vector<at::Tensor>
mha_varlen_fwd_sparse(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                      const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
                      const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
                      const at::Tensor &block_count,
                      const at::Tensor &block_offset,
                      const at::Tensor &column_count,
                      const at::Tensor &column_index,
                      const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                      const at::Tensor &cu_seqlens_q,  // b+1
                      const at::Tensor &cu_seqlens_k,  // b+1
                      const std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                      const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                      int64_t max_seqlen_q,
                      const int64_t max_seqlen_k,
                      const double p_dropout,
                      const double softmax_scale,
                      const bool zero_tensors,
                      bool is_causal,
                      const double softcap,
                      const bool return_softmax,
                      std::optional<at::Generator> gen_,
                      const double pv_threshold = 50.0,  // Dynamic PV skip threshold
                      const bool enable_dynamic_skip = true  // Enable dynamic PV skip optimization
                      ) {
    // auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    // bool is_sm8x_min = cc_major >= 8;
    // TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    // sizes[i] is not bounds-checked, so the rank has to be verified first.
    TORCH_CHECK(q.dim() == 3, "q must have shape (total_q, num_heads, head_size)");
    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size_og = sizes[2];
    const int num_heads_k = k.size(1);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
    int64_t window_size_left = -1;
    int64_t window_size_right = -1;
    if (is_causal) { window_size_right = 0; }

    if (print_param) {
        // q is 3-D here, so the former sizes[3] read was out of bounds. Print the
        // logical (batch, max_seqlen_q, heads, head_size) that the "bshd" label
        // announces, and cast the int64_t values to match the %d conversions.
        printf("mha_varlen_fwd_sparse fa input size bshd=(%d,%d,%d,%d),p_dropout=%.3f,softmax_scale=%.3f,is_causal=%d,window_size_left=%d,window_size_right=%d,softcap=%f,return_softmax=%d,is_bhsd=%d\n",
               batch_size, (int)max_seqlen_q, num_heads, head_size_og, p_dropout, softmax_scale, (int)is_causal,
               (int)window_size_left, (int)window_size_right, softcap, (int)return_softmax, 0);
    }

    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

    const int total_q = q.sizes()[0];

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    // Guard the modulo below against a zero divisor.
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    const int total_k = k.size(0);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);

    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()) {
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    // Pad the head dimension up to a multiple of 8.
    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
        // With padding the kernel writes a scratch buffer; out_ is filled by copy at the end.
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }

    if (zero_tensors) {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) { p.zero_(); }
    }

    Flash_fwd_params_sparse params;
    set_params_fprop_sparse(params,
                            batch_size,
                            max_seqlen_q, max_seqlen_k,
                            seqlen_q_rounded, seqlen_k_rounded,
                            num_heads, num_heads_k,
                            head_size, head_size_rounded,
                            q_padded, k_padded, v_padded,
                            block_count, block_offset,
                            column_count, column_index,
                            out,
                            cu_seqlens_q_d,
                            cu_seqlens_k.data_ptr(),
                            seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                            return_softmax ? p.data_ptr() : nullptr,
                            softmax_lse.data_ptr(),
                            p_dropout,
                            softmax_scale,
                            window_size_left,
                            window_size_right,
                            softcap);
    params.total_q = total_q;

    // Set dynamic PV skip parameters
    params.pv_threshold = static_cast<float>(pv_threshold);
    params.enable_dynamic_skip = enable_dynamic_skip;

    // NOTE(woosuk): the Philox RNG state setup (rng_state / philox_args derived
    // from gen_) is intentionally left out because it is not used in inference;
    // gen_ is therefore unused here.

    // for alibi_slopes_ cast away constness that was added for torch library
    // compatibility, needs to be cast away to maintain compatibility with
    // upstream
    set_params_alibi(params,
                     const_cast<std::optional<at::Tensor> &>(alibi_slopes_),
                     batch_size, num_heads);

    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd_sparse(params, stream);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    if (head_size_og % 8 != 0) {
        // Drop the padding columns again and mirror the result into the caller's buffer.
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, softmax_lse};
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.doc() = "FlashAttention";
// m.def("fwd_sparse", &mha_fwd_sparse, "Forward sparse pass");
// m.def("varlen_fwd_sparse", &mha_varlen_fwd_sparse, "Forward pass sparse (variable length)");
// }
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Adds an ALiBi positional bias to a 2-D accumulator fragment.
//
// The three public methods differ only in how a thread's column indices are
// derived from its lane id and in whether row and column swap roles:
//   apply_alibi            : first column = offset + lane/16,     step 4 between a thread's values
//   apply_alibi_continuous : first column = offset + (lane/16)*4, step 1 between a thread's values
//   apply_alibi_trans      : column layout as apply_alibi_continuous, row/column roles swapped
// With Is_causal the bias is `+ slope * index` (column index, or row index for
// the transposed form); otherwise it is `- slope * |distance|`.
template <bool Is_causal>
struct Alibi {

    const float alibi_slope;
    const int max_seqlen_k, max_seqlen_q;

    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
        : alibi_slope(alibi_slope), max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q) {}

    template <typename Engine, typename Layout>
    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                                const int col_idx_offset_,
                                                const int row_idx_offset,
                                                const int warp_row_stride) {
        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane = threadIdx.x % 64;
        add_bias_</*kThreadStride=*/4, /*kTransposed=*/false>(
            tensor, col_idx_offset_ + lane / 16, row_idx_offset, warp_row_stride);
    }

    template <typename Engine, typename Layout>
    __forceinline__ __device__ void apply_alibi_continuous(Tensor<Engine, Layout> &tensor,
                                                           const int col_idx_offset_,
                                                           const int row_idx_offset,
                                                           const int warp_row_stride) {
        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane = threadIdx.x % 64;
        add_bias_</*kThreadStride=*/1, /*kTransposed=*/false>(
            tensor, col_idx_offset_ + (lane / 16) * 4, row_idx_offset, warp_row_stride);
    }

    template <typename Engine, typename Layout>
    __forceinline__ __device__ void apply_alibi_trans(Tensor<Engine, Layout> &tensor,
                                                      const int col_idx_offset_,
                                                      const int row_idx_offset,
                                                      const int warp_row_stride) {
        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane = threadIdx.x % 64;
        add_bias_</*kThreadStride=*/1, /*kTransposed=*/true>(
            tensor, col_idx_offset_ + (lane / 16) * 4, row_idx_offset, warp_row_stride);
    }

private:
    // Shared loop nest for the three public entry points. Every element is
    // updated exactly once and independently of the others, so the traversal
    // order does not affect the result.
    template <int kThreadStride, bool kTransposed, typename Engine, typename Layout>
    __forceinline__ __device__ void add_bias_(Tensor<Engine, Layout> &tensor,
                                              const int first_col,
                                              const int first_row,
                                              const int row_step) const {
        constexpr int kRepeatStride = 16;
        #pragma unroll
        for (int mi = 0; mi < size<0>(tensor); ++mi) {
            [[maybe_unused]] const int row = first_row + mi * row_step;
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_base = first_col + nj * kRepeatStride;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    [[maybe_unused]] const int col = col_base + j * kThreadStride;
                    if constexpr (Is_causal) {
                        // Simpler case: the bias grows linearly with one index only.
                        if constexpr (kTransposed) {
                            tensor(mi, make_coord(j, nj)) += alibi_slope * row;
                        } else {
                            tensor(mi, make_coord(j, nj)) += alibi_slope * col;
                        }
                    } else {
                        // Bias depends on both row and column index.
                        if constexpr (kTransposed) {
                            tensor(mi, make_coord(j, nj)) -= alibi_slope * abs(col + max_seqlen_k - max_seqlen_q - row);
                        } else {
                            tensor(mi, make_coord(j, nj)) -= alibi_slope * abs(row + max_seqlen_k - max_seqlen_q - col);
                        }
                    }
                }
            }
        }
    }
};
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Per-batch-element sequence bookkeeping derived from a kernel parameter
// struct: start offsets into packed Q/K storage and the effective Q/K lengths.
//
// sum_s_q / sum_s_k hold -1 when the corresponding cumulative-length array is
// not in use (non-varlen, null pointer, or - for K - non-cumulative lengths);
// q_offset/k_offset then fall back to batch-stride addressing.
template<bool Varlen=true>
struct BlockInfo {

    // Standard construction from cu_seqlens_* / seqused_k / leftpad_k.
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(Varlen && params.cu_seqlens_q != nullptr ? params.cu_seqlens_q[bidb] : -1)
        , sum_s_k(Varlen && params.cu_seqlens_k != nullptr && params.is_seqlens_k_cumulative
                      ? params.cu_seqlens_k[bidb]
                      : -1)
        , actual_seqlen_q(Varlen && params.cu_seqlens_q != nullptr
                              ? params.cu_seqlens_q[bidb + 1] - sum_s_q
                              : params.seqlen_q)
        , leftpad_k(params.leftpad_k != nullptr ? params.leftpad_k[bidb] : 0)
        // With cumulative K lengths, seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb];
        // otherwise cu_seqlens_k[bidb] itself stores the length of K for this batch element.
        , seqlen_k_cache((Varlen && params.cu_seqlens_k != nullptr
                              ? (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                                                                : params.cu_seqlens_k[bidb])
                              : params.seqlen_k)
                         - leftpad_k)
        , actual_seqlen_k(params.seqused_k
                              ? params.seqused_k[bidb] - leftpad_k
                              : seqlen_k_cache + (params.knew_ptr != nullptr ? params.seqlen_knew : 0))
        {
        }

    // Variant that takes both effective lengths from params.padding_mask[bidb].
    // The `padding_mask` argument only selects this overload; its value is not read.
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb, const bool padding_mask)
        : sum_s_q(Varlen && params.cu_seqlens_q != nullptr ? params.cu_seqlens_q[bidb] : -1)
        , sum_s_k(Varlen && params.cu_seqlens_k != nullptr && params.is_seqlens_k_cumulative
                      ? params.cu_seqlens_k[bidb]
                      : -1)
        , actual_seqlen_q(params.padding_mask[bidb])
        , leftpad_k(params.leftpad_k != nullptr ? params.leftpad_k[bidb] : 0)
        // With cumulative K lengths, seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb];
        // otherwise cu_seqlens_k[bidb] itself stores the length of K for this batch element.
        , seqlen_k_cache((Varlen && params.cu_seqlens_k != nullptr
                              ? (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                                                                : params.cu_seqlens_k[bidb])
                              : params.seqlen_k)
                         - leftpad_k)
        , actual_seqlen_k(params.padding_mask[bidb])
        {
        }

    // Element offset of this batch element's first Q row.
    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q != -1 ? uint32_t(sum_s_q) * row_stride : bidb * batch_stride;
    }

    // Element offset of this batch element's first K row, skipping leftpad_k rows.
    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k != -1 ? uint32_t(sum_s_k + leftpad_k) * row_stride
                             : bidb * batch_stride + leftpad_k * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // Declaration order matters: leftpad_k and seqlen_k_cache must come before
    // actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
    const int leftpad_k;
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include "philox.cuh"
#include "utils.h"
namespace flash {
struct Dropout {
const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}
// Applies dropout in place to an accumulator fragment.
//
// For every (m, n) position of the re-laid-out tensor one Philox call yields
// 16 random bytes; element i is kept when byte i is <= p_dropout_in_uint8_t.
// Dropped elements become 0, or -val when encode_dropout_in_sign_bit is set.
//
// block_row_start / block_col_start seed the (row, col) counter fed to Philox;
// the row advances by block_row_stride per m and the column by 1 per n.
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                              int block_row_start, int block_col_start, int block_row_stride) {
    // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
    // NOTE(review): the loop below only touches indices 0..3 of mode 0, so on
    // this target the converted layout presumably keeps 4 values per (m, n) — confirm.
    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
        if constexpr(encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    #pragma unroll
    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
        uint2 rowcol = make_uint2(block_row_start, block_col_start);
        #pragma unroll
        for (int n = 0; n < size<2>(tensor); ++n, ++rowcol.y) {
            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
            // Background (describes a packed fast path that is NOT implemented here):
            // special implementation for 16-bit types - replicate the threshold into
            // the low and high 16 bits of a 32-bit value and use an f16x2 compare
            // instruction to obtain a mask. Each 16-bit half of the mask is 0xffff or
            // 0x0000 depending on whether the random value is below the threshold.
            // The mask is then ANDed with the original 32-bit value. This relies on
            // floating-point comparison being equivalent to integer comparison for
            // unsigned integers whose top 8 bits are zero.
            // This function uses the plain per-byte comparison instead.
            #pragma unroll
            for (int i = 0; i < 4; i++) {
                tensor(i, m, n) = encode_dropout(rnd_8[i] <= p_dropout_in_uint8_t, tensor(i, m, n));
            }
        }
    }
}
// In-place dropout for an accumulator fragment addressed as (mi, i, j) in its native layout
// (no layout conversion is applied).
//
// Every element gets its own Philox draw, keyed by its coordinate:
//   row = block_row_start + i * block_row_stride
//   col = block_col_start + (lane / 16) * 4 + j * 16 + mi,   with lane = threadIdx.x % 64
// Only byte 0 of the draw is used: the element is kept when that byte is <= p_dropout_in_uint8_t
// and is otherwise zeroed (or negated when encode_dropout_in_sign_bit is true).
//
// Parameters:
//   tensor_          : fragment to mask, modified in place.
//   block_row_start  : row coordinate of index i == 0.
//   block_col_start  : column coordinate of the block's first column.
//   block_row_stride : row coordinate step between consecutive i.
// NOTE(review): the 64 / 16 / 4 constants presumably mirror the MMA accumulator layout of a
// 64-lane wavefront — confirm against the tiled MMA definition.
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_continuous(Tensor<Engine, Layout> &tensor_,
                                     int block_row_start, int block_col_start, int block_row_stride) {
    using T = typename Engine::value_type;
    // View over the caller's fragment with its layout unchanged.
    Tensor frag = make_tensor(tensor_.data(), tensor_.layout());
    auto mask_value = [](bool keep, T val) {
        if constexpr(encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    constexpr int kColsPerLaneGroup = 4;
    constexpr int kColsPerRepeat = 16;
    const int lane = threadIdx.x % 64;
    const int first_col = block_col_start + (lane / 16) * kColsPerLaneGroup;
    for (int i = 0; i < size<1>(frag); ++i) {
        const int row = block_row_start + i * block_row_stride;
        for (int j = 0; j < size<2>(frag); ++j) {
            const int repeat_col = first_col + j * kColsPerRepeat;
            for (int mi = 0; mi < size<0>(frag); ++mi) {
                // The (row, col) pair is reinterpreted as the 64-bit Philox counter.
                uint2 rowcol = make_uint2(row, repeat_col + mi);
                uint4 draw = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                uint8_t (&bytes)[16] = reinterpret_cast<uint8_t (&)[16]>(draw);
                frag(mi, i, j) = mask_value(bytes[0] <= p_dropout_in_uint8_t, frag(mi, i, j));
            }
        }
    }
}
// FP8 variant of the per-element dropout: same scheme as apply_dropout_continuous but with a
// wider column footprint per lane group.
//
// The fragment is addressed as (mi, i, j) in its native layout (no layout conversion). Each element
// gets its own Philox draw keyed by:
//   row = block_row_start + i * block_row_stride
//   col = block_col_start + (lane / 16) * 8 + j * 32 + mi,   with lane = threadIdx.x % 64
// Byte 0 of the draw is compared with p_dropout_in_uint8_t; the element is kept when the byte is
// <= the threshold, otherwise zeroed (or negated when encode_dropout_in_sign_bit is true).
//
// Parameters:
//   tensor_          : fragment to mask, modified in place.
//   block_row_start  : row coordinate of index i == 0.
//   block_col_start  : column coordinate of the block's first column.
//   block_row_stride : row coordinate step between consecutive i.
// NOTE(review): the 8 / 32 constants presumably follow the FP8 MMA accumulator layout — confirm.
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_continuous_fp8(Tensor<Engine, Layout> &tensor_,
                                     int block_row_start, int block_col_start, int block_row_stride) {
    using T = typename Engine::value_type;
    // View over the caller's fragment with its layout unchanged.
    Tensor frag = make_tensor(tensor_.data(), tensor_.layout());
    auto mask_value = [](bool keep, T val) {
        if constexpr(encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    constexpr int kColsPerLaneGroup = 8;
    constexpr int kColsPerRepeat = 32;
    const int lane = threadIdx.x % 64;
    const int first_col = block_col_start + (lane / 16) * kColsPerLaneGroup;
    for (int i = 0; i < size<1>(frag); ++i) {
        const int row = block_row_start + i * block_row_stride;
        for (int j = 0; j < size<2>(frag); ++j) {
            const int repeat_col = first_col + j * kColsPerRepeat;
            for (int mi = 0; mi < size<0>(frag); ++mi) {
                // The (row, col) pair is reinterpreted as the 64-bit Philox counter.
                uint2 rowcol = make_uint2(row, repeat_col + mi);
                uint4 draw = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                uint8_t (&bytes)[16] = reinterpret_cast<uint8_t (&)[16]>(draw);
                frag(mi, i, j) = mask_value(bytes[0] <= p_dropout_in_uint8_t, frag(mi, i, j));
            }
        }
    }
}
// Per-element dropout with the Philox coordinate transposed.
//
// The fragment is re-viewed through flash::convert_layout_acc_dropout and addressed as (mi, i, j).
// Index arithmetic matches apply_dropout_continuous:
//   row_idx = block_row_start + i * block_row_stride
//   col_idx = block_col_start + (lane / 16) * 4 + j * 16 + mi,   with lane = threadIdx.x % 64
// but the Philox counter is built as (col_idx, row_idx), i.e. with the two components swapped.
// Byte 0 of the draw is compared with p_dropout_in_uint8_t; the element is kept when the byte is
// <= the threshold, otherwise zeroed (or negated when encode_dropout_in_sign_bit is true).
//
// Parameters:
//   tensor_          : fragment to mask, modified in place.
//   block_row_start  : base of the coordinate stepped by i.
//   block_col_start  : base of the coordinate stepped by lane / j / mi.
//   block_row_stride : step applied per i.
// NOTE(review): presumably used where the tile is computed transposed so that the mask matches
// the one produced by the non-transposed path — confirm at the call sites.
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_trans(Tensor<Engine, Layout> &tensor_,
                                     int block_row_start, int block_col_start, int block_row_stride) {
    using T = typename Engine::value_type;
    Tensor frag = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
    auto mask_value = [](bool keep, T val) {
        if constexpr(encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    constexpr int kColsPerLaneGroup = 4;
    constexpr int kColsPerRepeat = 16;
    const int lane = threadIdx.x % 64;
    const int first_col = block_col_start + (lane / 16) * kColsPerLaneGroup;
    for (int i = 0; i < size<1>(frag); ++i) {
        const int row = block_row_start + i * block_row_stride;
        for (int j = 0; j < size<2>(frag); ++j) {
            const int repeat_col = first_col + j * kColsPerRepeat;
            for (int mi = 0; mi < size<0>(frag); ++mi) {
                // Column goes first here: this is the transposed counter.
                uint2 colrow = make_uint2(repeat_col + mi, row);
                uint4 draw = flash::philox(seed, reinterpret_cast<unsigned long long&>(colrow), offset);
                uint8_t (&bytes)[16] = reinterpret_cast<uint8_t (&)[16]>(draw);
                frag(mi, i, j) = mask_value(bytes[0] <= p_dropout_in_uint8_t, frag(mi, i, j));
            }
        }
    }
}
// Dropout variant that makes a single Philox draw per i index and spends its 16 bytes across
// the (j, mi) elements of that index.
//
// The fragment is addressed as (mi, i, j) in its native layout (no layout conversion). For each i the
// counter is
//   row = block_row_start + i * block_row_stride + (threadIdx.x / 64) * 16
//   col = block_col_start + ((threadIdx.x % 64) / 16) * 4
// and element (mi, i, j) is decided by byte j * 4 + mi of the draw: kept when the byte is
// <= p_dropout_in_uint8_t, otherwise zeroed (or negated when encode_dropout_in_sign_bit is true).
//
// Parameters:
//   tensor_          : fragment to mask, modified in place.
//   block_row_start  : row coordinate of index i == 0 for the first wavefront.
//   block_col_start  : column coordinate of the block's first column.
//   block_row_stride : row coordinate step between consecutive i.
// NOTE(review): byte index j * 4 + mi stays inside the 16-byte draw only if size<0> <= 4 and
// size<2> <= 4 — assumed to hold for the layouts this is used with; TODO confirm.
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_continuous_opt(Tensor<Engine, Layout> &tensor_,
                                     int block_row_start, int block_col_start, int block_row_stride) {
    using T = typename Engine::value_type;
    // View over the caller's fragment with its layout unchanged.
    Tensor frag = make_tensor(tensor_.data(), tensor_.layout());
    auto mask_value = [](bool keep, T val) {
        if constexpr(encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    const int lane = threadIdx.x % 64;
    const int col = block_col_start + (lane / 16) * 4;
    for (int i = 0; i < size<1>(frag); ++i) {
        const int row = block_row_start + i * block_row_stride + (threadIdx.x / 64) * 16;
        uint2 rowcol = make_uint2(row, col);
        uint4 draw = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
        uint8_t (&bytes)[16] = reinterpret_cast<uint8_t (&)[16]>(draw);
        for (int j = 0; j < size<2>(frag); ++j) {
            for (int mi = 0; mi < size<0>(frag); ++mi) {
                const bool keep = bytes[j * 4 + mi] <= p_dropout_in_uint8_t;
                frag(mi, i, j) = mask_value(keep, frag(mi, i, j));
            }
        }
    }
}
// Transposed dropout that draws one Philox block per thread per i index and redistributes the
// random bytes between threads through shared memory.
//
// For every i:
//   1. each thread draws 16 random bytes for the counter (col_idx_offset, row_idx_base) and
//      stores them as one uint4 at scratch[row_ * RAND_STRIDE + col_];
//   2. after a block barrier, each thread reads one byte per (mi, j) element from
//      scratch[(read_row + j * 16 + mi) * RAND_STRIDE + read_col];
//   3. the element is kept when the byte is <= p_dropout_in_uint8_t, otherwise zeroed
//      (or negated when encode_dropout_in_sign_bit is true).
//
// The scratch area is dynamic shared memory starting at byte offset 16384 and is rewritten on
// every iteration. Readers fetch rows that other wavefronts wrote, so a second barrier closes each
// iteration: without it a wavefront that runs ahead could overwrite bytes a slower wavefront has
// not read yet (write-after-read race).
//
// Requirements:
//   - must be reached by every thread of the block (it contains block-wide barriers);
//   - NOTE(review): assumes bytes [16384, 16384 + rows * 64) of dynamic shared memory are free
//     for use here and 16-byte aligned — confirm against the kernel's shared-memory plan.
//
// Parameters:
//   tensor_          : fragment to mask, modified in place (viewed via convert_layout_acc_dropout).
//   block_row_start  : base of the coordinate stepped by i.
//   block_col_start  : base of the coordinate selected per wavefront.
//   block_row_stride : step applied per i.
template <bool encode_dropout_in_sign_bit = false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_trans_opt(
    Tensor<Engine, Layout> &tensor_,
    int block_row_start, int block_col_start, int block_row_stride)
{
    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
        if constexpr (encode_dropout_in_sign_bit) {
            return keep ? val : -val;
        } else {
            return keep ? val : T(0);
        }
    };
    const int lane_id = threadIdx.x % 64;
    const int col_idx_offset = block_col_start + (threadIdx.x / 64) * 16;
    extern __shared__ char smem_[];
    uint8_t *p_rand_8 = reinterpret_cast<uint8_t *>(smem_ + 16384);
    // Where this thread stores its 16 random bytes.
    int row_ = (threadIdx.x % 16) + (threadIdx.x / 64) * 16;
    int col_ = (lane_id / 16) * 16;
    // Where this thread reads the bytes that decide its own elements.
    const int read_row = (lane_id / 16) * 4;
    const int lane_group = (lane_id % 16) / 4;
    const int lane_offset = lane_id % 4;
    const int read_col = (threadIdx.x / 64) * 4 + lane_group * 16 + lane_offset;
    // Bytes per scratch row (unpadded).
    constexpr int RAND_STRIDE = 64;
    for (int i = 0; i < size<1>(tensor); ++i) {
        const int row_idx_base = block_row_start + i * block_row_stride + (lane_id / 16) * 4;
        // Transposed counter: the column component goes first.
        uint2 rowcol = make_uint2(col_idx_offset, row_idx_base);
        uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long &>(rowcol), offset);
        *reinterpret_cast<uint4*>(&p_rand_8[row_ * RAND_STRIDE + col_]) = random_uint4;
        // Make every thread's bytes visible before anyone reads them.
        __syncthreads();
        #pragma unroll
        for (int j = 0; j < size<2>(tensor); ++j) {
            #pragma unroll
            for (int mi = 0; mi < size<0>(tensor); ++mi) {
                const int rand_read_row = read_row + j * 16 + mi;
                const uint8_t t_rand = p_rand_8[(rand_read_row) * RAND_STRIDE + read_col];
                tensor(mi, i, j) =
                    encode_dropout(t_rand <= p_dropout_in_uint8_t, tensor(mi, i, j));
            }
        }
        // All reads of this iteration must finish before the scratch is overwritten again.
        __syncthreads();
    }
}
};
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
// get environment variables for internal usage
// Reads a boolean switch from the environment.
// Returns false when `env_var` is unset or set to exactly "0"; any other value
// (including the empty string) counts as true.
static inline bool get_env_(const char *env_var) {
    const char *setting = std::getenv(env_var);
    if (setting == nullptr) {
        return false;
    }
    return strcmp(setting, "0") != 0;
}
static std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
// Axis indices 0, 1 and 2.
// NOTE(review): the names suggest the (total tokens, heads, head dim) axes of a packed
// variable-length tensor — no use site is visible in this part of the file; confirm.
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Pointers, strides and head counts describing the Q, K and V inputs.
// Base of the forward and backward parameter structs.
struct Qkv_params {
// Integer type used for all strides and offsets (64-bit).
using index_t = int64_t;
// The QKV matrices. Untyped; the kernels cast them to their element type.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// Strides of Q, K and V between batch entries, rows and heads.
// NOTE(review): units appear to be elements, not bytes — the kernels add them to typed
// element pointers; confirm for every kernel that consumes this struct.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads: h for the queries, h_k for keys/values.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameters of the forward attention kernels: the Qkv_params inputs plus outputs, softmax
// statistics, problem sizes, variable-length / KV-cache metadata, dropout and RNG state, and
// feature flags. Only max_mm_ranges and use_alibi_sqrt have default initializers; every other
// member must be set by the caller.
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// NOTE(review): presumably the partial-output accumulator of the split-KV path — confirm.
void * __restrict__ oaccum_ptr;
// Strides of O between batch entries, rows and heads.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax log-sum-exp, and its accumulator counterpart.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
// NOTE(review): presumably b = batch size, d = head dim, *_rounded = padded sizes,
// total_q = summed query length over the batch in the varlen layout — confirm.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// For FP8 scaling: descale factors for Q, K and V and their batch/head strides.
float * __restrict__ q_descale_ptr;
float * __restrict__ k_descale_ptr;
float * __restrict__ v_descale_ptr;
index_t q_descale_batch_stride;
index_t q_descale_head_stride;
index_t k_descale_batch_stride;
index_t k_descale_head_stride;
index_t v_descale_batch_stride;
index_t v_descale_head_stride;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// NOTE(review): presumably per-batch left-padding of K, and a padding mask — confirm layout.
int * __restrict__ leftpad_k;
int * __restrict__ padding_mask;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
// NOTE(review): presumably a block-sparsity mask — layout is not visible here; confirm.
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// Strides of K_new and V_new between batch entries, rows and heads.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
// 8-bit form of the keep threshold.
// NOTE(review): presumably the value handed to flash::Dropout, which keeps an element when
// its random byte is <= this threshold — confirm at the kernel call site.
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Local window size
int window_size_left, window_size_right;
float softcap;
// Random state.
at::PhiloxCudaState philox_args;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
// Data-type and masking flags.
bool is_bf16;
bool is_fp8;
bool is_e4m3;
bool is_causal;
bool is_vllm_kvcache;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
// Attention Sinks: precomputed LogSumExp for sink tokens
// Shape: [nheads], dtype: float32 (ElementAccum). Maximum 64 heads supported (shared memory limit).
// Used for streaming LLM inference to maintain attention to initial "sink" tokens.
void * __restrict__ s_aux_ptr;
// NOTE(review): presumably the V head dimension and its padded size for kernels where V's
// head dim differs from Q/K's (the *_mla_ launchers take a separate HeaddimV) — confirm.
int d_value, d_value_rounded;
// NOTE(review): block-skipping controls; semantics are not visible in this file section.
float skip_softmax_threshold_scale_factor;
void * skip_blocks_info_ptr;
void * __restrict__ debug_ptr; // for debug
// NOTE(review): additive query-query bias and its leading stride, presumably — confirm.
void * __restrict__ qq_bias_ptr;
int qq_bias_stride_0;
// NOTE(review): multimodal prefix ranges and their maximum count, presumably — confirm.
int * __restrict__ mm_prefix_range_ptr;
int max_mm_ranges = 0;
bool use_alibi_sqrt = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameters of the backward attention kernels: everything in Flash_fwd_params plus the incoming
// gradient dO, the output gradients dQ/dK/dV, their accumulators and strides.
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
// To accumulate dK and dV.
// NOTE(review): an earlier, commented-out declaration of these two pointers described them as
// used when the backward pass is split along the seqlen_q dimension — confirm this still holds.
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
// NOTE(review): index_t is declared as int64_t in Qkv_params, so the 32-bit remark above
// looks stale for this struct — confirm whether any kernel still narrows these strides.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
// NOTE(review): presumably selects a deterministic accumulation scheme; dq_accum_split_stride
// is then the stride between per-split dQ accumulators — confirm.
bool deterministic;
index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launcher entry points, declared here and defined elsewhere.
// Common template parameters:
//   T         : element type of the inputs (TO, where present, is a second element type —
//               presumably the output type of the FP8 paths; confirm).
//   Headdim   : head dimension; HeaddimV is a separate head dimension for V in the *_mla_ variants.
//   Is_causal : compile-time causal-masking switch.
// Every launcher takes the parameter struct by reference and the stream to launch on.

// Forward launchers.
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, int HeaddimV, bool Is_causal> void run_mha_fwd_mla_(Flash_fwd_params &params, cudaStream_t stream);
// Split-KV forward launchers, including the FP8 variants.
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T,typename TO, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t stream);
template<typename T,typename TO, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch_kv_fp8(Flash_fwd_params &params, cudaStream_t stream);
// Other specialised forward launchers.
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_padding_mask_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_blasst_fwd_(Flash_fwd_params &params, cudaStream_t stream);
// Non-template launcher; by its name, a variable-length path specialised for head dimension 64.
void run_mha_varlen_tiny_fwd_dim64(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, int HeaddimV, bool Is_causal> void run_mha_fwd_splitkv_mla_dispatch(Flash_fwd_params &params, cudaStream_t stream);
// Backward launchers.
template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, int HeaddimV, bool Is_causal> void run_mha_bwd_mla_(Flash_bwd_params &params, cudaStream_t stream);
#pragma once
#include "flash.h"
// Parameter struct for attention mask forward path, isolated from the main
// Flash_fwd_params to avoid polluting the common kernel interface.
// Forward parameters extended with an explicit attention mask.
struct Flash_fwd_params_attnmask : public Flash_fwd_params {
// Attention mask pointer and strides.
// Expected layout: [b, h, seqlen_q, seqlen_k], with K dim contiguous
// (K stride == 1). Only the Q stride is configurable here.
// NOTE(review): element type is presumably bool (one byte per position), as in the backward
// path — confirm against the forward kernel.
void * __restrict__ mask_ptr;
// Strides between batch entries, heads and query rows of the mask.
index_t mask_batch_stride;
index_t mask_head_stride;
index_t mask_seq_q_stride;
// Value to write when mask is false.
float masked_value;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Forward entry point for attention with explicit mask.
// T is the element type, Headdim the head dimension and Is_causal the compile-time causal switch.
// Declared here only; the definition is not in this header.
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_attnmask_(Flash_fwd_params_attnmask &params, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameter struct for attention mask backward path.
// Backward parameters extended with an explicit attention mask.
struct Flash_bwd_params_attnmask : public Flash_bwd_params {
// Attention mask pointer and strides.
// Expected layout: [b, h, seqlen_q, seqlen_k], with K dim contiguous
// (K stride == 1). Only the Q stride is configurable here.
// The backward dQ kernel reinterprets this pointer as bool* and offsets it by
// bidb * mask_batch_stride + bidh * mask_head_stride, so the strides count mask elements.
void * __restrict__ mask_ptr;
index_t mask_batch_stride;
index_t mask_head_stride;
index_t mask_seq_q_stride;
// Value used when mask is false (typically -INFINITY in forward).
// In backward, positions where mask is false have P=0, so dS should also be 0.
float masked_value;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Backward entry point for attention with explicit mask.
// T is the element type, Headdim the head dimension and Is_causal the compile-time causal switch.
// Declared here only; it is explicitly instantiated elsewhere for selected
// (T, Headdim, Is_causal) combinations.
template<typename T, int Headdim, bool Is_causal>
void run_mha_bwd_attnmask_(Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: bfloat16, head dimension 128, causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::bfloat16_t, 128, true>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: bfloat16, head dimension 128, non-causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::bfloat16_t, 128, false>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: half, head dimension 128, causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::half_t, 128, true>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: half, head dimension 128, non-causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::half_t, 128, false>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: bfloat16, head dimension 64, causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::bfloat16_t, 64, true>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: bfloat16, head dimension 64, non-causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::bfloat16_t, 64, false>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: half, head dimension 64, causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::half_t, 64, true>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
// Explicit instantiation: half, head dimension 64, non-causal.
#include "flash_bwd_attnmask_launch_template.h"
template void run_mha_bwd_attnmask_<cutlass::half_t, 64, false>(
Flash_bwd_params_attnmask &params, cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2026, Attnmask extension.
* Backward kernel for attention with explicit mask support.
*
* This file contains the backward kernels modified to support explicit attention masks.
* The key modification is applying the attention mask when recomputing S = QK^T,
* before the softmax (scale_apply_exp2), to ensure P = 0 at masked positions.
******************************************************************************/
#pragma once
#include "flash_bwd_kernel.h"
#include "flash_attnmask.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
// dQ computation with attention mask support (dim128 prefetch version)
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dq_1rowblock_16x64_prefetch_attnmask(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
const int warpId = __builtin_amdgcn_readfirstlane(tidx / 64);
const int laneId = tidx % 64;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kStages = Kernel_traits::kStages;
using SdP_TiledShape_MNK = typename Kernel_traits::TiledMmaSdP::TiledShape_MNK;
constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(SdP_TiledShape_MNK{}))::value;;
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
// ============ Attention Mask initialization ============
bool* mask_ptr = Use_mask ? reinterpret_cast<bool*>(params.mask_ptr)
+ bidb * params.mask_batch_stride + bidh * params.mask_head_stride : nullptr;
// ======================================================
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if constexpr (Is_causal || Is_local) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + m_block * kBlockM;
const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// ============ Attention Mask tensor setup ============
Tensor mM = make_tensor(mask_ptr,
make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.mask_seq_q_stride, _1{}));
Tensor gM = local_tile(mM, Shape<Int<kBlockM>, Int<kBlockN>>{}, make_coord(m_block, _));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKGemm0{});
Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKGemm1transposed{});
Tensor sV = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutVGemm0{});
// S/dP
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
Tensor tSrQ = thr_mma_sdp.partition_fragment_A(gQ);
Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);
Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(gdO);
Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);
// dQ
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKt);
auto gmem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto gmem_thr_copy_QdO = gmem_tiled_copy_QdO.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_QdO.partition_S(gQ);
Tensor tdPgdO = gmem_thr_copy_QdO.partition_S(gdO);
auto smem_tiled_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x64BLayout tiled_mma_BLayout;
auto smem_tiled_copy_BLayout = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_BLayout);
auto smem_thr_copy_BLayout = smem_tiled_copy_BLayout.get_thread_slice(tidx);
Tensor sVtemp = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutV{});
Tensor tdPsVBLayout = smem_thr_copy_BLayout.partition_S(sVtemp);
Tensor tdPsV = make_tensor(tdPsVBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tdPsVBLayout.layout()));
Tensor sKtemp = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutK{});
Tensor tSsKBLayout = smem_thr_copy_BLayout.partition_S(sKtemp);
Tensor tSsK = make_tensor(tSsKBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tSsKBLayout.layout()));
auto smem_tiled_copy_Kt = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16_WITH_8x64, Element>{}, tiled_mma_dq);
auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
Tensor tdQsKt8x64 = smem_thr_copy_Kt.partition_S(sKt);
Tensor tdQsKt = make_tensor(tdQsKt8x64.data(), convert_layout_B_rowcol<_16x128>(tdQsKt8x64.layout()));
// PREDICATES
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor cdO = make_identity_tensor(make_shape(size<0>(gdO), size<1>(gdO)));
Tensor tQcQ = gmem_thr_copy_QdO.partition_D(cQ);
Tensor tdOcdO = gmem_thr_copy_QdO.partition_D(cdO);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdPgdO)));
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }
}
// ============ Attention Mask partition ============
Tensor tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block_max > 0 ? n_block_max - 1 : 0));
Tensor tSrM = make_fragment_like<uint8_t>(tSgM);
clear(tSrM);
// Identity tensor for mask predicates
// gM shape: [kBlockM, kBlockN], get<0> is Q direction, get<1> is K direction
// Prologue
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
clear(tdQrdQ);
Tensor cdQ = make_identity_tensor(make_shape(size<0>(gdQ), size<1>(gdQ)));
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
if constexpr(!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
}
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
return;
}
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QdO, tSgQ, tSrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QdO, tdPgdO, tdPrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM);
Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor taccScS = thr_mma_sdp.partition_C(caccS);
Tensor taccScS_row = taccScS(0, _, 0);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
}
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
int n_block = n_block_max - 1;
constexpr int k0_loops = size<2>(tSsK);
constexpr int k1_loops = size<2>(tdPsV);
constexpr int k2_loops = size<2>(tdQsKt);
static_assert(kStages <= k0_loops && kStages <= k1_loops && kStages <= k2_loops , "kStages is error");
#pragma unroll
for (int i = 0; i < kStages; i++) {
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});
clear(acc_dq);
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
// ============ Pre-read mask for first iteration ============
if constexpr (Use_mask) {
if (n_block >= n_block_min) {
tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block));
cute::copy(tSgM, tSrM);
}
}
#pragma unroll
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; i++) {
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gK, sK, kStages + i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, i);
asm volatile("s_barrier");
}
#pragma unroll
for (int i = 0; i < kStages; i++) {
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gV, sV, i, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, k0_loops - kStages + i);
asm volatile("s_barrier");
}
Tensor acc_s = make_tensor(acc_s_ori.data(), flash::convert_layout_acc(acc_s_ori.layout()));
if constexpr (Is_softcap) {
flash::apply_softcap(acc_s, params.softcap);
}
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
[[maybe_unused]] Tensor dtanh = make_tensor_like(scores);
if constexpr (Is_softcap) {
// Compute dtanh before masking to avoid -inf -> NaN in backward
flash::calculate_dtanh(scores, dtanh, params.softcap);
}
// ============ Apply attention mask ============
// Apply mask BEFORE alibi and causal masking, after softcap
flash::apply_atten_mask<Use_mask>(tSrM, acc_s_ori, params.masked_value);
#if 1
if constexpr (Has_alibi) {
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
const int col_idx_offset = n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16;
const int row_idx_offset = m_block * kBlockM + get<0>(taccScS_row(0));
const int warp_row_stride = AtomLayoutMS * 16;
alibi.apply_alibi_continuous(scores, col_idx_offset, row_idx_offset, warp_row_stride);
}
#endif
#if 1
if constexpr (!Is_causal && !Is_local) {
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
const int col_idx_offset_ = n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16;
flash::apply_mask_continuous(scores, binfo.actual_seqlen_k, col_idx_offset_);
}
} else if constexpr (Is_causal) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
flash::apply_mask_causal_continuous(scores, n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16);
}
} else if constexpr (Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
flash::apply_mask_local_continuous(scores, n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, AtomLayoutMS * 16,
params.window_size_left, params.window_size_right);
}
}
#endif
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
#if 1
if constexpr (Is_dropout) {
const int wave_id = __builtin_amdgcn_readfirstlane(tidx / 64);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
const int block_row_idx = row_idx_offset_;
const int block_col_idx = n_block * (kBlockN);
if constexpr (kHeadDim==128){
dropout.template apply_dropout_continuous_opt</*encode_dropout_in_sign_bit=*/true>(
acc_s, m_block * kBlockM, block_col_idx, AtomLayoutMS * 16
);
}else{
dropout.template apply_dropout_continuous</*encode_dropout_in_sign_bit=*/true>(
acc_s, block_row_idx, block_col_idx, AtomLayoutMS * 16
);
}
}
#endif
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_dp_ori);
#pragma unroll
for (int i = 0; i < k1_loops - kStages; i++) {
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gV, sV, kStages + i, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, i);
asm volatile("s_barrier");
}
#pragma unroll
for (int i = 0; i < kStages; i++) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gK, sKt, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, k1_loops - kStages + i);
asm volatile("s_barrier");
}
Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#if 1
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
dS(mi, ni) = scaled_ds;
}
}
#endif
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
Tensor tdQrdS = flash::convert_type<Element>(dS_reshaped);
#pragma unroll
for (int i = 0; i < k2_loops - kStages; i++) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gK, sKt, kStages + i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, i);
asm volatile("s_barrier");
}
if (n_block > n_block_min) {
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// ============ Pre-read mask for next iteration ============
if constexpr (Use_mask) {
tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block - 1));
cute::copy(tSgM, tSrM);
}
#pragma unroll
for (int i = 0; i < kStages; i++) {
lds_direct_copy<Is_even_K>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, k2_loops - kStages + i);
asm volatile("s_barrier");
}
}
else if (kStages == 3){
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, 1);
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, 2);
asm volatile("s_barrier");
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, 3);
asm volatile("s_barrier");
} else {
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
#pragma unroll
for (int i = 0; i < kStages; ++i) {
flash::gemm_k_rs(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt, k2_loops - kStages + i);
asm volatile("s_barrier");
}
}
}
// Epilogue: write dQ
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
using GmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
auto gmem_tiled_copy_dQ = make_tiled_copy_C(GmemCopyAtom{}, tiled_mma_dq);
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQrdQ = gmem_thr_copy_dQ.retile_S(acc_dq);
Tensor taccdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor cdQ = make_identity_tensor(make_shape(size<0>(gdQ), size<1>(gdQ)));
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
#pragma unroll
for (int m = 0; m < size<1>(taccdQrdQ); m++) {
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int k = 0; k < size<2>(taccdQrdQ); k++) {
const int col_id = get<1>(tdQcdQ(0, 0, k));
for (int i = 0; i < size<0>(taccdQrdQ); i++) {
if (Is_even_K || col_id + i * 4 < params.d) {
taccdQgdQ(i, m, k) = flash::convert_type<Element>(taccdQrdQ(i, m, k) * params.scale_softmax_rp_dropout);
}
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// dQ computation with attention mask support (dim64 version)
// Based on compute_dq_1rowblock_16x64_dim64_prefetch with mask logic inserted.
////////////////////////////////////////////////////////////////////////////////////////////////////
// Computes dQ for one query row block (`m_block`) of batch `bidb`, head `bidh`.
//
// The K/V column blocks are walked from `n_block_max - 1` down to `n_block_min`.
// For each block the function:
//   1. computes S = Q * K^T (acc_s_ori), optionally applies softcap, the
//      user attention mask (Use_mask), ALiBi and the causal/local/padding masks;
//   2. turns S into P with scale_apply_exp2 using the stored LSE, optionally
//      applies dropout with the drop decision encoded in the sign bit;
//   3. computes dP = dO * V^T (acc_dp_ori) and dS = P * (dP - dP_sum);
//   4. accumulates dQ += dS * K into acc_dq.
// Finally acc_dq is scaled by params.scale_softmax_rp_dropout and written to
// params.dq_ptr with row/column bounds checks.
//
// Template flags select compile-time variants (dropout, causal, local window,
// ALiBi, even sequence/head-dim sizes, softcap, attention mask).
//
// Preconditions enforced by static_assert below: kBlockN == 64 and the K-loop
// counts are (2, 2, 4), i.e. this is the head-dim-64 specialisation.
//
// Shared memory (dynamic, `smem_`): sK sits at the base; sKt/sKtNoSwizzle alias
// the same storage with transposed layouts; sV is placed directly after sK.
//
// NOTE(review): the `s_waitcnt vmcnt(N)` thresholds appear to assume a fixed
// number of outstanding vector-memory operations per lds_direct_copy call and a
// fixed issue order (K0, K1, V0, V1, then Kt0..Kt3) -- confirm before reordering
// any of the copies.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
         bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dq_1rowblock_16x64_dim64_prefetch_attnmask(
    const Params &params, const int bidb, const int bidh, const int m_block) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    extern __shared__ char smem_[];

    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    using SdP_TiledShape_MNK = typename Kernel_traits::TiledMmaSdP::TiledShape_MNK;
    constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(SdP_TiledShape_MNK{}))::value;
    constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;

    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    // Row block lies entirely past the end of this sequence: nothing to do.
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    // ============ Attention Mask initialization ============
    // Base of the mask for this (batch, head); null when the mask is disabled.
    bool* mask_ptr = Use_mask ? reinterpret_cast<bool*>(params.mask_ptr)
        + bidb * params.mask_batch_stride + bidh * params.mask_head_stride : nullptr;

    // Range of K/V column blocks [n_block_min, n_block_max) visited by this row block.
    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if constexpr (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
    }

    // Global-memory element offsets. K and V start at the LAST column block
    // (n_block_max - 1) because the main loop iterates backwards.
    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_lse = (params.unpadded_lse
        ? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
        : (bidb * params.h + bidh) * params.seqlen_q) + m_block * kBlockM;
    const index_t row_offset_dpsum = (params.unpadded_lse
        ? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb
        : (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

    // Global-memory tile views.
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));
    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
    Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    // ============ Attention Mask tensor setup ============
    // (seqlen_q, seqlen_k) view of the mask, tiled into kBlockM x kBlockN blocks
    // along the K direction for this row block.
    Tensor mM = make_tensor(mask_ptr,
                            make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k),
                            make_stride(params.mask_seq_q_stride, _1{}));
    Tensor gM = local_tile(mM, Shape<Int<kBlockM>, Int<kBlockN>>{}, make_coord(m_block, _));

    // Shared-memory views. sKt / sKtNoSwizzle alias sK's storage; sV follows sK.
    Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKGemm0{});
    Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKGemm1transposed{});
    Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKGemm1transposedNoSwizzle{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutVGemm0{});

    // S/dP MMA: register fragments for Q, K, dO, V.
    typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
    auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma_sdp.partition_fragment_A(gQ);
    Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);
    Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(gdO);
    Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);

    // dQ MMA: register fragment for K^T.
    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
    Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);

    // Q and dO are loaded from global memory directly into the MMA A-operand layout.
    auto gmem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto gmem_thr_copy_QdO = gmem_tiled_copy_QdO.get_thread_slice(tidx);
    Tensor tSgQ = gmem_thr_copy_QdO.partition_S(gQ);
    Tensor tdPgdO = gmem_thr_copy_QdO.partition_S(gdO);

    // Shared -> register copies for the B operands (K, V, K^T).
    auto smem_tiled_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);

    typename Kernel_traits::TiledMma16x64BLayout tiled_mma_BLayout;
    auto smem_tiled_copy_BLayout = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_BLayout);
    auto smem_thr_copy_BLayout = smem_tiled_copy_BLayout.get_thread_slice(tidx);

    Tensor sVtemp = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutV{});
    Tensor tdPsVBLayout = smem_thr_copy_BLayout.partition_S(sVtemp);
    Tensor tdPsV = make_tensor(tdPsVBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tdPsVBLayout.layout()));

    Tensor sKtemp = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutK{});
    Tensor tSsKBLayout = smem_thr_copy_BLayout.partition_S(sKtemp);
    Tensor tSsK = make_tensor(tSsKBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tSsKBLayout.layout()));

    auto smem_tiled_copy_Kt = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16_WITH_8x64, Element>{}, tiled_mma_dq);
    auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
    Tensor tdQsKt8x64 = smem_thr_copy_Kt.partition_S(sKt);
    Tensor tdQsKt = make_tensor(tdQsKt8x64.data(), convert_layout_B_rowcol<_16x64_64>(tdQsKt8x64.layout()));

    // PREDICATES: identity tensors give each thread the (row, col) of its elements;
    // tQpQ / tdOpdO flag head-dim columns that are < params.d.
    Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
    Tensor cdO = make_identity_tensor(make_shape(size<0>(gdO), size<1>(gdO)));
    Tensor tQcQ = gmem_thr_copy_QdO.partition_D(cQ);
    Tensor tdOcdO = gmem_thr_copy_QdO.partition_D(cdO);
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdPgdO)));
    if constexpr (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }
    }

    // ============ Attention Mask partition ============
    // Per-thread view of the mask tile in the S accumulator layout, plus a
    // zero-initialised register copy. The index is clamped to 0 so the view is
    // well-formed even when n_block_max == 0.
    Tensor tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block_max > 0 ? n_block_max - 1 : 0));
    Tensor tSrM = make_fragment_like<uint8_t>(tSgM);
    clear(tSrM);

    // Early exit: no K/V block contributes to this row block, so dQ is all zeros.
    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
        const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
            + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
        Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 make_stride(params.dq_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
        auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
        Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
        Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
        clear(tdQrdQ);
        Tensor cdQ = make_identity_tensor(make_shape(size<0>(gdQ), size<1>(gdQ)));
        Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
        Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
        if constexpr(!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
        }
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
        );
        return;
    }

    // Load Q and dO into registers; rows past the end of the sequence are cleared.
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_QdO, tSgQ, tSrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_QdO, tdPgdO, tdPrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM);

    // Per-thread row coordinates of the S accumulator, used to fetch LSE / dP_sum.
    Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
    Tensor taccScS = thr_mma_sdp.partition_C(caccS);
    Tensor taccScS_row = taccScS(0, _, 0);
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        const int row = get<0>(taccScS_row(mi));
        // Out-of-range rows get LSE = +inf so that exp2(score - lse) evaluates to 0.
        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
    }
    Tensor dP_sum = make_fragment_like(lse);
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }

    int n_block = n_block_max - 1;

    constexpr int k0_loops = size<2>(tSsK);
    constexpr int k1_loops = size<2>(tdPsV);
    constexpr int k2_loops = size<2>(tdQsKt);
    static_assert(k0_loops == 2 && k1_loops == 2 && k2_loops == 4 && kBlockN == 64, "kblockn should be 64");

    // Prologue prefetch: both K chunks, then both V chunks, for the first block.
    lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
    lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
    lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
    lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 1, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);

    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});
    clear(acc_dq);

    // ALiBi slope is pre-divided by scale_softmax; 0 when ALiBi is disabled.
    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

    // ============ Pre-read mask for first iteration ============
    // NOTE(review): the mask tile is copied without bounds predicates; confirm the
    // mask buffer tolerates full-tile reads on partial edge blocks.
    if constexpr (Use_mask) {
        if (n_block >= n_block_min) {
            tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block));
            cute::copy(tSgM, tSrM);
        }
    }

    #pragma unroll
    for (; n_block >= n_block_min; --n_block) {
        // ---- S = Q * K^T, consuming the two prefetched K chunks ----
        Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
        clear(acc_s_ori);

        asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
        flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 0);
        asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
        flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 1);
        asm volatile("s_barrier");

        // sK is no longer needed: reload the same K block into the aliased
        // transposed layout (sKt) for the dQ GEMM later in this iteration.
        lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
        lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
        lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
        lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
        asm volatile("s_barrier");

        // acc_s / scores are re-layouts of acc_s_ori's registers, not copies.
        Tensor acc_s = make_tensor(acc_s_ori.data(), flash::convert_layout_acc(acc_s_ori.layout()));
        if constexpr (Is_softcap) {
            flash::apply_softcap(acc_s, params.softcap);
        }
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        [[maybe_unused]] Tensor dtanh = make_tensor_like(scores);
        if constexpr (Is_softcap) {
            // Computed before any masking is applied to the scores.
            flash::calculate_dtanh(scores, dtanh, params.softcap);
        }

        // ============ Apply attention mask ============
        // Applied after softcap and before ALiBi / causal / local masking.
        flash::apply_atten_mask<Use_mask>(tSrM, acc_s_ori, params.masked_value);

        #if 1
        if constexpr (Has_alibi) {
            const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
            const int col_idx_offset = n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16;
            const int row_idx_offset = m_block * kBlockM + get<0>(taccScS_row(0));
            const int warp_row_stride = AtomLayoutMS * 16;
            alibi.apply_alibi_continuous(scores, col_idx_offset, row_idx_offset, warp_row_stride);
        }
        #endif

        #if 1
        // Structural masking; only invoked for blocks that can contain masked entries.
        if constexpr (!Is_causal && !Is_local) {
            if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
                const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
                const int col_idx_offset_ = n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16;
                flash::apply_mask_continuous(scores, binfo.actual_seqlen_k, col_idx_offset_);
            }
        } else if constexpr (Is_causal) {
            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
                flash::apply_mask_causal_continuous(scores, n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16,
                                                    binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                                    binfo.actual_seqlen_q,
                                                    AtomLayoutMS * 16);
            }
        } else if constexpr (Is_local) {
            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
                flash::apply_mask_local_continuous(scores, n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16,
                                                   binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                                   binfo.actual_seqlen_q, AtomLayoutMS * 16,
                                                   params.window_size_left, params.window_size_right);
            }
        }
        #endif

        // scores now hold P, recomputed from the stored LSE (no running max).
        flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);

        #if 1
        if constexpr (Is_dropout) {
            const int wave_id = __builtin_amdgcn_readfirstlane(tidx / 64);
            const int wave_id_to_row_block_id = wave_id;
            const int warp_row_stride = 16;
            const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
            const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
            const int block_row_idx = row_idx_offset_;
            const int block_col_idx = n_block * (kBlockN);
            // Dropped entries are marked through the sign bit; pointwise_mult below
            // tests `p >= 0` to distinguish kept from dropped elements.
            dropout.template apply_dropout_continuous</*encode_dropout_in_sign_bit=*/true>(
                acc_s, block_row_idx, block_col_idx, AtomLayoutMS * 16
            );
        }
        #endif

        // ---- dP = dO * V^T, consuming the two prefetched V chunks ----
        Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
        clear(acc_dp_ori);

        asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
        flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 0);
        asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
        flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 1);
        asm volatile("s_barrier");

        // ---- dS = P * (dP - dP_sum), written in place over dP's registers ----
        Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
        Tensor dS = make_tensor(acc_dp.data(), scores.layout());
        auto pointwise_mult = [](float p, float dp, float d) {
            return p * (!Is_dropout || p >= 0 ? dp - d : d);
        };
        #pragma unroll
        for (int mi = 0; mi < size<0>(dS); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(dS); ++ni) {
                float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                dS(mi, ni) = scaled_ds;
            }
        }

        Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
        Tensor tdQrdS = flash::convert_type<Element>(dS_reshaped);

        // ---- dQ += dS * K, consuming the four K^T chunks as they arrive ----
        asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
        flash::gemm_k_rs_ds_read_m32x16<0>(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
        asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
        flash::gemm_k_rs_ds_read_m32x16<1>(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
        asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
        flash::gemm_k_rs_ds_read_m32x16<2>(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
        asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
        flash::gemm_k_rs_ds_read_m32x16<3>(acc_dq, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
        asm volatile("s_barrier");

        if (n_block > n_block_min) {
            // Step both global pointers back one column block. V must be stepped by
            // its own row stride: gV was built with v_row_stride, and K and V row
            // strides need not be equal.
            gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
            gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));

            // ============ Pre-read mask for next iteration ============
            if constexpr (Use_mask) {
                tSgM = thr_mma_sdp.partition_C(gM(_, _, n_block - 1));
                cute::copy(tSgM, tSrM);
            }

            // Prefetch K and V for the next block. Is_even_MN is forced to true here;
            // presumably because every block other than the last one is a full tile --
            // TODO confirm against lds_direct_copy.
            lds_direct_copy<Is_even_K, true>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
            lds_direct_copy<Is_even_K, true>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
            lds_direct_copy<Is_even_K, true>(gV, sV, 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
            lds_direct_copy<Is_even_K, true>(gV, sV, 1, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
        }
    }

    // Epilogue: write dQ straight from the accumulator registers to global memory.
    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));

    using GmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
    auto gmem_tiled_copy_dQ = make_tiled_copy_C(GmemCopyAtom{}, tiled_mma_dq);
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQrdQ = gmem_thr_copy_dQ.retile_S(acc_dq);
    Tensor taccdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

    Tensor cdQ = make_identity_tensor(make_shape(size<0>(gdQ), size<1>(gdQ)));
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);

    #pragma unroll
    for (int m = 0; m < size<1>(taccdQrdQ); m++) {
        // Skip rows that fall past the end of the query sequence.
        if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
            #pragma unroll
            for (int k = 0; k < size<2>(taccdQrdQ); k++) {
                const int col_id = get<1>(tdQcdQ(0, 0, k));
                for (int i = 0; i < size<0>(taccdQrdQ); i++) {
                    // Element i of a thread's fragment is treated as column col_id + 4*i.
                    if (Is_even_K || col_id + i * 4 < params.d) {
                        taccdQgdQ(i, m, k) = flash::convert_type<Element>(taccdQrdQ(i, m, k) * params.scale_softmax_rp_dropout);
                    }
                }
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// dQ wrapper with attention mask support (dispatches by kHeadDim)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Block-level entry point for the dQ pass with attention-mask support.
//
// The work item is read from the launch grid: blockIdx.x is the query row block,
// blockIdx.y the head and blockIdx.z the batch entry. The per-row-block worker is
// selected at compile time from Kernel_traits::kHeadDim (128 or 64); any other head
// dimension compiles to a no-op.
//
// Unless NO_CAUSAL_OPT is defined, causal kernels additionally process the mirrored
// row block (num_blocks - 1 - blockIdx.x) in the same thread block, skipping it when
// it coincides with the first one (presumably to even out causal work per block;
// this relies on the launch grid chosen by the host side - confirm there).
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
         bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dq_seqq_parallel_16x64_prefetch_attnmask(const Params &params) {
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    const int batch_idx = blockIdx.z;
    const int head_idx = blockIdx.y;
    const int first_row_block = blockIdx.x;

    // Runs the head-dim specific dQ worker for a single query row block.
    auto run_row_block = [&](const int row_block) {
        if constexpr (kHeadDim == 128) {
            compute_dq_1rowblock_16x64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
                Is_even_MN, Is_even_K, Is_softcap, Use_mask>(params, batch_idx, head_idx, row_block);
        } else if constexpr (kHeadDim == 64) {
            compute_dq_1rowblock_16x64_dim64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
                Is_even_MN, Is_even_K, Is_softcap, Use_mask>(params, batch_idx, head_idx, row_block);
        }
    };

    run_row_block(first_row_block);

#ifndef NO_CAUSAL_OPT
    if constexpr (Is_causal && (kHeadDim == 128 || kHeadDim == 64)) {
        // Number of query row blocks, derived from the padded/max seqlen_q in params.
        const int row_block_count = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
        const int mirrored_row_block = row_block_count - first_row_block - 1;
        // The middle block of an odd-sized range mirrors onto itself; do not process it twice.
        if (mirrored_row_block != first_row_block) {
            run_row_block(mirrored_row_block);
        }
    }
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// dK/dV computation with attention mask support (dim128 prefetch version)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Computes dK and dV for one key/value column block (n_block) of batch entry `bidb` and
// head `bidh`, with optional attention-mask support (Use_mask).
//
// The K and V tiles of this column block are loaded once into per-thread fragments
// (tSrK / tdPrV). The main loop then walks the query row blocks from m_block_max - 1
// down to m_block_min, staging Q and dO through shared memory with lds_direct_copy.
// Each iteration performs, on the transposed [kBlockN, kBlockM] score tile:
//   S^T  = K * Q^T            (optionally soft-capped)
//   attention mask, alibi, sequence-bounds / causal / local masking
//   P^T  = exp2(S^T * scale_softmax_log2 - lse)
//   dropout (dropped entries are encoded in the sign bit)
//   dV  += P^T * dO
//   dP^T = V * dO^T
//   dS^T = P^T * (dP^T - dP_sum)   (times dtanh when soft-capped)
//   dK  += dS^T * Q
// The epilogue scales the accumulators and stores dK / dV to global memory.
//
// Parameters:
//   params  - kernel parameter struct (pointers, strides, sizes, scales).
//   bidb    - batch index.
//   bidh    - (query) head index; the K/V head is bidh / params.h_h_k_ratio.
//   n_block - index of the key/value column block processed by this thread block.
//
// NOTE(review): the tail of the dK GEMM uses the hard-coded k-slices 1, 2, 3 and the Q
// prefetch for the next iteration uses the hard-coded stages 0, 1, 2, i.e. the code
// appears to assume kdK_loops == 4 and kStages == 3 - confirm against Kernel_traits.
// NOTE(review): S_WAITCNT* / S_BARRIER are macros defined outside this block; they
// presumably wait on outstanding memory operations and synchronize the thread block.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
         bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dk_dv_trans_1colblock_16x64_prefetch_attnmask(const Params &params, const int bidb, const int bidh, const int n_block) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
    extern __shared__ char smem_[];
    const int tidx = threadIdx.x;
    // Waves are 64 lanes wide in this kernel.
    const int warpId = tidx / 64;
    const int laneId = tidx % 64;
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int kStages = Kernel_traits::kStages;
    constexpr int kSmemOffset = Kernel_traits::kSmemOffset;
    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    // Column block lies entirely past the end of this sequence: nothing to do.
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
    // ============ Attention Mask initialization ============
    // Base pointer of the mask for this (batch, head); nullptr when masking is disabled.
    bool* mask_ptr = Use_mask ? reinterpret_cast<bool*>(params.mask_ptr)
        + bidb * params.mask_batch_stride + bidh * params.mask_head_stride : nullptr;
    // One past the last query row block that can interact with this key block.
    int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
    if constexpr (Is_local) {
        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
    }
    // Global-memory offsets. Q / dO / LSE / dP_sum start at the LAST row block
    // (m_block_max - 1) because the main loop iterates backwards.
    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
    const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDimV>>{},
                            make_stride(params.v_row_stride, _1{}));
    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDimV>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
    Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});
    // ============ Attention Mask tensor setup (transposed for dK/dV) ============
    // Original mask layout: [seqlen_q, seqlen_k] with mask[q,k] indicating if q attends to k
    // In dK/dV: S = K @ Q^T has shape [kBlockN, kBlockM] i.e. [key, query]
    // So we need transposed view: [seqlen_k, seqlen_q] to match S layout
    // This way mask_transposed[k,q] = mask[q,k] aligns with S[k,q]
    Tensor mM = make_tensor(mask_ptr,
                            make_shape(binfo.actual_seqlen_k, binfo.actual_seqlen_q),
                            make_stride(_1{}, params.mask_seq_q_stride));
    // For dK/dV: fixed n_block (key block), varying m_block (query block)
    // gM shape is [kBlockN, kBlockM] to match S = K @ Q^T layout
    Tensor gM = local_tile(mM, Shape<Int<kBlockN>, Int<kBlockM>>{}, make_coord(n_block, _));
    // Shared-memory views. Q and dO each get several layouts over the same storage
    // (GEMM-0 layout, transposed layout for the second GEMM, and a non-swizzled
    // transposed layout used only to shape the register fragments).
    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQGemm0{});
    Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQGemm1transposed{});
    Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQGemm1transposedNoSwizzle{});
    Tensor sdO = make_tensor(sQ.data() + kSmemOffset, typename Kernel_traits::SmemLayoutdOGemm0{});
    Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdOGemm1transposed{});
    Tensor sdOtNoSwizzle = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdOGemm1transposedNoSwizzle{});
    // S/dP
    typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
    auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
    Tensor tSrK = thr_mma_sdp.partition_fragment_A(gK);
    Tensor tSrQ = thr_mma_sdp.partition_fragment_B(sQ);
    Tensor tdPrV = thr_mma_sdp.partition_fragment_A(gV);
    Tensor tdPrdO = thr_mma_sdp.partition_fragment_B(sdO);
    // dV/dK
    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
    Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtNoSwizzle);
    Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);
    // Copy Atom retiling
    auto gmem_tiled_copy_KV = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
    Tensor tSgK = gmem_thr_copy_KV.partition_S(gK);
    Tensor tdPgV = gmem_thr_copy_KV.partition_S(gV);
    auto smem_tiled_copy_QdO = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_sdp);
    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
    typename Kernel_traits::TiledMma16x64BLayout tiled_mma_BLayout;
    auto smem_tiled_copy_BLayout = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_BLayout);
    auto smem_thr_copy_BLayout = smem_tiled_copy_BLayout.get_thread_slice(tidx);
    Tensor sQtemp = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQ{});
    Tensor tSsQBLayout = smem_thr_copy_BLayout.partition_S(sQtemp);
    Tensor tSsQ = make_tensor(tSsQBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tSsQBLayout.layout()));
    Tensor sdOtemp = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdO{});
    Tensor tdPsdOBLayout = smem_thr_copy_BLayout.partition_S(sdOtemp);
    Tensor tdPsdO = make_tensor(tdPsdOBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDimV/32>(tdPsdOBLayout.layout()));
    auto smem_tiled_copy_QdOt = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16_WITH_8x64, Element>{}, tiled_mma_dkv);
    auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
    Tensor tdVsdOt8x64 = smem_thr_copy_QdOt.partition_S(sdOt);
    Tensor tdVsdOt = make_tensor(tdVsdOt8x64.data(), convert_layout_B_rowcol<_16x128>(tdVsdOt8x64.layout()));
    Tensor tdKsQt8x64 = smem_thr_copy_QdOt.partition_S(sQt);
    Tensor tdKsQt = make_tensor(tdKsQt8x64.data(), convert_layout_B_rowcol<_16x128>(tdKsQt8x64.layout()));
    // PREDICATES
    // Column predicates for the K / V loads; only filled when the head dim is not "even".
    Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK)));
    Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV)));
    Tensor tKcK = gmem_thr_copy_KV.partition_D(cK);
    Tensor tVcV = gmem_thr_copy_KV.partition_D(cV);
    Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tSgK)));
    Tensor tVpV = make_tensor<bool>(make_shape(size<2>(tdPgV)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.d_value; }
    }
    // Query row blocks are visited from m_block_max - 1 down to m_block_min (inclusive).
    int m_block = m_block_max - 1;
    int m_block_min = (!Is_causal && !Is_local)
        ? 0
        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
    // ============ Attention Mask partition for dK/dV ============
    // gM is now transposed: shape [kBlockN, kBlockM, num_m_tiles] matching S = K @ Q^T
    // gM(_, _, m_block) selects the m_block-th query tile, giving [kBlockN, kBlockM]
    // get<0> is K direction, get<1> is Q direction
    Tensor tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block_max > 0 ? m_block_max - 1 : 0));
    // Per-thread copy of the mask tile, prefetched one iteration ahead of its use.
    Tensor tSrM = make_fragment_like<uint8_t>(tSgM);
    clear(tSrM);
    // Early exit: no query row block contributes to this key block, so dK and dV for
    // the block are simply written as zeros.
    if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
            + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
            + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                 make_stride(params.dk_row_stride, _1{}));
        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                                 Shape<Int<kBlockN>, Int<kHeadDimV>>{},
                                 make_stride(params.dv_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
        clear(tdKrdK);
        clear(tdVrdV);
        Tensor cdK = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));
        Tensor cdV = make_identity_tensor(make_shape(size<0>(gdV), size<1>(gdV)));
        Tensor tdKcdK = gmem_thr_copy_dKV.partition_D(cdK);
        Tensor tdVcdV = gmem_thr_copy_dKV.partition_D(cdV);
        Tensor tdKpdK = make_tensor<bool>(make_shape(size<2>(tdKcdK)));
        Tensor tdVpdV = make_tensor<bool>(make_shape(size<2>(tdVcdV)));
        #pragma unroll
        for (int k = 0; k < size(tdKpdK); ++k) { tdKpdK(k) = get<1>(tdKcdK(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tdVpdV); ++k) { tdVpdV(k) = get<1>(tdVcdV(0, 0, k)) < params.d_value; }
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN
        );
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN
        );
        return;
    }
    // Load the K and V tiles of this column block into fragments (out-of-range rows cleared).
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_KV, tSgK, tSrK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_KV, tdPgV, tdPrV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    Tensor caccS = make_identity_tensor(Shape<Int<kBlockN>, Int<kBlockM>>{});
    Tensor taccScS = thr_mma_sdp.partition_C(caccS);
    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);
    // dV / dK accumulators, summed over all query row blocks.
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDimV>>{});
    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
    clear(acc_dv);
    clear(acc_dk);
    Tensor taccScS_row = taccScS(_, 0, _);
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    // Load the log-sum-exp values of the first (last-in-sequence) query row block.
    // Rows past the end of the sequence get +INFINITY so that exp2(s - lse) becomes 0.
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
    }
    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
    // Number of k-slices of each GEMM; the shared-memory pipeline must not be deeper than any of them.
    constexpr int kS_loops = size<2>(tSsQ);
    constexpr int kdV_loops = size<2>(tdVsdOt);
    constexpr int kdP_loops = size<2>(tdPsdO);
    constexpr int kdK_loops = size<2>(tdKsQt);
    static_assert(kStages <= kS_loops && kStages <= kdV_loops && kStages <= kdP_loops && kStages <= kdK_loops, "kStages is error");
    // Prologue: issue the first kStages Q slices for the first row block.
    #pragma unroll
    for (int i = 0; i < kStages; ++i) {
        lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
    }
    // ============ Pre-read mask for first iteration ============
    // NOTE(review): the mask tile is copied without row/column predicates - confirm the
    // mask buffer is padded to full tiles when the sequence lengths are not tile multiples.
    if constexpr (Use_mask) {
        if (m_block >= m_block_min) {
            tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block));
            cute::copy(tSgM, tSrM);
        }
    }
    #pragma unroll
    for (; m_block >= m_block_min; m_block--) {
        // ---- GEMM 1: S^T = K * Q^T. While it runs, the remaining Q slices and then the
        // first dO slices (transposed layout, for the dV GEMM) are issued. ----
        Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
        clear(acc_s_ori);
        #pragma unroll
        for (int i = 0; i < kS_loops - kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, kStages + i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, i);
            S_BARRIER;
        }
        #pragma unroll
        for (int i = 0; i < kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gdO, sdOt, i, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, kS_loops - kStages + i);
            S_BARRIER;
        }
        Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
        if constexpr (Is_softcap) {
            flash::apply_softcap(acc_s, params.softcap);
        }
        Tensor scores_trans = make_tensor(acc_s.data(), flash::convert_trans_layout_acc_rowcol(acc_s.layout()));
        [[maybe_unused]] Tensor dtanh_trans = make_tensor_like(scores_trans);
        if constexpr (Is_softcap) {
            // Compute dtanh before masking to avoid -inf -> NaN in backward
            flash::calculate_dtanh(scores_trans, dtanh_trans, params.softcap);
        }
        // ============ Apply attention mask (transposed for dK/dV) ============
        // For dK/dV, S has shape [kBlockN, kBlockM] (transposed)
        // Apply mask AFTER softcap so that masked positions keep params.masked_value
        flash::apply_atten_mask<Use_mask>(tSrM, acc_s_ori, params.masked_value);
#if 1
        if constexpr (Has_alibi) {
            Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
            const int wave_id = tidx / 64;
            const int col_idx_offset = m_block * kBlockM;
            const int wave_id_to_row_block_id = wave_id;
            const int warp_row_stride = 16;
            const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
            const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
            alibi.apply_alibi_trans(scores, col_idx_offset, row_idx_offset_, kNWarps * 16);
        }
#endif
#if 1
        // Sequence-bounds / causal / local masking; only applied to tiles that can
        // actually contain masked elements.
        if constexpr(!Is_causal && !Is_local) {
            if (!Is_even_MN && (m_block + 1) * kBlockM >= binfo.actual_seqlen_q) {
                Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
                const int col_idx_offset_ = m_block * kBlockM;
                flash::apply_mask_trans(scores, binfo.actual_seqlen_q, col_idx_offset_);
            }
        } else if constexpr(Is_causal) {
            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) {
                Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
                const int wave_id = (tidx >> 6);
                const int wave_id_to_row_block_id = wave_id;
                const int warp_row_stride = 16;
                const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
                const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
                flash::apply_mask_causal_trans(
                    scores,
                    m_block * kBlockM,
                    binfo.actual_seqlen_k,
                    row_idx_offset_,
                    binfo.actual_seqlen_q,
                    kNWarps * 16
                );
            }
        } else if constexpr(Is_local) {
            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left) {
                Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
                const int wave_id = (tidx >> 6);
                const int wave_id_to_row_block_id = wave_id;
                const int warp_row_stride = 16;
                const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
                const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
                flash::apply_mask_local_trans(
                    scores,
                    m_block * kBlockM,
                    binfo.actual_seqlen_k,
                    row_idx_offset_,
                    binfo.actual_seqlen_q,
                    kNWarps * 16,
                    params.window_size_left, params.window_size_right
                );
            }
        }
#endif
        // P^T = exp2(S^T * scale_softmax_log2 - lse), computed in place.
        flash::scale_apply_exp2</*scale_max=*/false>(scores_trans, lse, params.scale_softmax_log2);
        Tensor dP_sum = make_fragment_like(lse);
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
            dP_sum(mi) = gdPsum(row);
        }
        if (m_block > m_block_min) {
            // Step the per-row pointers to the previous row block and prefetch its LSE
            // (and mask tile) for the next iteration.
            gdPsum.data() = gdPsum.data() + (-int(kBlockM));
            gLSE.data() = gLSE.data() + (-int(kBlockM));
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
                lse(mi) = gLSE(row);
            }
            // ============ Pre-read mask for next iteration ============
            if constexpr (Use_mask) {
                tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block - 1));
                cute::copy(tSgM, tSrM);
            }
        }
        if constexpr (Is_dropout) {
            const int wave_id = (tidx >> 6);
            const int wave_id_to_row_block_id = wave_id;
            const int warp_row_stride = 16;
            const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
            const int row_idx_offset_ = (kHeadDim == 128) ? (n_block * kBlockN) : (n_block * kBlockN + row_idx_offset_in_block);
            int block_row_idx = row_idx_offset_;
            int block_col_idx = m_block * kBlockM;
            if constexpr (kHeadDim==128){
                dropout.template apply_dropout_trans_opt</*encode_dropout_in_sign_bit=*/true>(
                    acc_s, n_block * kBlockN, m_block * kBlockM, kNWarps * 16
                );
            }else{
                dropout.template apply_dropout_trans</*encode_dropout_in_sign_bit=*/true>(
                    acc_s, block_row_idx, block_col_idx, kNWarps * 16
                );
            }
        }
        // Convert P^T to the GEMM element type; with dropout the sign-encoded (dropped)
        // entries go through the relu variant of the conversion.
        Tensor rP = !Is_dropout
            ? flash::convert_type<Element>(acc_s)
            : flash::convert_type_relu<Element>(acc_s);
        // ---- GEMM 2: dV += P^T * dO. While it runs, the remaining transposed dO slices
        // and then the first dO slices in GEMM-0 layout (for dP) are issued. ----
        #pragma unroll
        for (int i = 0; i < kdV_loops - kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gdO, sdOt, kStages + i, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, i);
            S_BARRIER;
        }
        #pragma unroll
        for (int i = 0; i < kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, i, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, kdV_loops - kStages + i);
            S_BARRIER;
        }
        // ---- GEMM 3: dP^T = V * dO^T. While it runs, the remaining dO slices and then
        // the first transposed Q slices (for the dK GEMM) are issued. ----
        Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
        clear(acc_dp_ori);
        #pragma unroll
        for (int i = 0; i < kdP_loops - kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, kStages + i, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, i);
            S_BARRIER;
        }
        #pragma unroll
        for (int i = 0; i < kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gQ, sQt, i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, kdP_loops - kStages + i);
            S_BARRIER;
        }
        Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
        Tensor dS = make_tensor(acc_dp.data(), scores_trans.layout());
        // dS = P * (dP - dP_sum). With dropout, entries whose P is negative (sign-encoded
        // as dropped) use P * dP_sum instead.
        auto pointwise_mult = [](float p, float dp, float d) {
            return p * (!Is_dropout || p >= 0 ? dp - d : d);
        };
        #pragma unroll
        for (int mi = 0; mi < size<0>(dS); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(dS); ++ni) {
                float scaled_ds = pointwise_mult(scores_trans(mi, ni), dS(mi, ni), dP_sum(mi));
                if constexpr (Is_softcap) { scaled_ds *= dtanh_trans(mi, ni); }
                dS(mi, ni) = scaled_ds;
            }
        }
        Tensor tdKrdSt = flash::convert_type<Element>(acc_dp);
        // ---- GEMM 4: dK += dS^T * Q. The remaining transposed Q slices are issued first;
        // the tail below drains the pipeline with explicit wait counts. ----
        #pragma unroll
        for (int i = 0; i < kdK_loops - kStages; ++i) {
            lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gQ, sQt, kStages + i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
            S_WAITCNT;
            flash::gemm_k_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, i);
            S_BARRIER;
        }
        // NOTE(review): hard-coded k-slices 1..3 (see the note in the function header).
        S_WAITCNT2;
        flash::gemm_k_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, 1);
        S_BARRIER;
        S_WAITCNT1;
        flash::gemm_k_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, 2);
        S_BARRIER;
        S_WAITCNT0;
        flash::gemm_k_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt, 3);
        S_BARRIER;
        if (m_block > m_block_min) {
            // Move Q / dO to the previous row block and start prefetching its Q slices.
            // No row bound is passed here: every block other than the last one is a full block.
            gQ.data() = gQ.data() + (-int(kBlockM * params.q_row_stride));
            gdO.data() = gdO.data() + (-int(kBlockM * params.do_row_stride));
            lds_direct_copy<Is_even_K>(gQ, sQ, 0, params.q_row_stride, params.d);
            lds_direct_copy<Is_even_K>(gQ, sQ, 1, params.q_row_stride, params.d);
            lds_direct_copy<Is_even_K>(gQ, sQ, 2, params.q_row_stride, params.d);
        }
    }
    // Epilogue: write dK and dV
    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDimV>>{},
                             make_stride(params.dv_row_stride, _1{}));
    // Each thread stores its accumulator elements directly. The (row, col) of element
    // (ei, mi, ni) is: row = (mi * kNWarps + warpId) * 16 + laneId % 16,
    //                  col = laneId / 16 + ni * 32 + ei * 4.
    // dK columns are bounded by params.d and dV columns by params.d_value, matching the
    // predicates used for the K / V loads and for the zero-fill path above.
    int row, col;
    if constexpr (size<1>(acc_dk) == size<1>(acc_dv) && size<2>(acc_dk) == size<2>(acc_dv)) {
        // dK and dV fragments have the same shape: store both in a single sweep.
        #pragma unroll
        for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
            row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
            if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
                #pragma unroll
                for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
                    col = (laneId / 16) + ni * 32;
                    #pragma unroll
                    for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
                        if (Is_even_K || col < params.d) {
                            gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
                        }
                        if (Is_even_K || col < params.d_value) {
                            gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout );
                        }
                        col += 4;
                    }
                }
            }
        }
    } else {
        // Different fragment shapes (kHeadDim != kHeadDimV): store dK and dV separately.
        #pragma unroll
        for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
            row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
            if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
                #pragma unroll
                for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
                    col = (laneId / 16) + ni * 32;
                    #pragma unroll
                    for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
                        if (Is_even_K || col < params.d) {
                            gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
                        }
                        col += 4;
                    }
                }
            }
        }
        #pragma unroll
        for (int mi = 0; mi < size<1>(acc_dv); ++mi) {
            row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
            if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
                #pragma unroll
                for (int ni = 0; ni < size<2>(acc_dv); ++ni) {
                    col = (laneId / 16) + ni * 32;
                    #pragma unroll
                    for (int ei = 0; ei < size<0>(acc_dv); ++ei) {
                        if (Is_even_K || col < params.d_value) {
                            gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout);
                        }
                        col += 4;
                    }
                }
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// dK/dV computation with attention mask support (dim64 version)
// Based on compute_dk_dv_trans_1colblock_16x64_dim64_prefetch with mask logic inserted.
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch_attnmask(
const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
extern __shared__ char smem_[];
const int tidx = threadIdx.x;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
// ============ Attention Mask initialization ============
bool* mask_ptr = Use_mask ? reinterpret_cast<bool*>(params.mask_ptr)
+ bidb * params.mask_batch_stride + bidh * params.mask_head_stride : nullptr;
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
if constexpr (Is_local) {
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
}
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_lse = (params.unpadded_lse
? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
: (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
const index_t row_offset_dpsum = (params.unpadded_lse
? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb
: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// ============ Attention Mask tensor setup (transposed for dK/dV) ============
Tensor mM = make_tensor(mask_ptr,
make_shape(binfo.actual_seqlen_k, binfo.actual_seqlen_q),
make_stride(_1{}, params.mask_seq_q_stride));
Tensor gM = local_tile(mM, Shape<Int<kBlockN>, Int<kBlockM>>{}, make_coord(n_block, _));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQGemm0{});
Tensor sQt = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutQGemm1transposed{});
Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQGemm1transposedNoSwizzle{});
Tensor sdO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdOGemm0{});
Tensor sdOt = make_tensor(sdO.data() + size(sQ), typename Kernel_traits::SmemLayoutdOGemm1transposed{});
Tensor sdOtNoSwizzle = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdOGemm1transposedNoSwizzle{});
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
Tensor tSrK = thr_mma_sdp.partition_fragment_A(gK);
Tensor tSrQ = thr_mma_sdp.partition_fragment_B(sQ);
Tensor tdPrV = thr_mma_sdp.partition_fragment_A(gV);
Tensor tdPrdO = thr_mma_sdp.partition_fragment_B(sdO);
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtNoSwizzle);
Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);
auto gmem_tiled_copy_KV = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tSgK = gmem_thr_copy_KV.partition_S(gK);
Tensor tdPgV = gmem_thr_copy_KV.partition_S(gV);
auto smem_tiled_copy_QdO = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_sdp);
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x64BLayout tiled_mma_BLayout;
auto smem_tiled_copy_BLayout = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_BLayout);
auto smem_thr_copy_BLayout = smem_tiled_copy_BLayout.get_thread_slice(tidx);
Tensor sQtemp = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQ{});
Tensor tSsQBLayout = smem_thr_copy_BLayout.partition_S(sQtemp);
Tensor tSsQ = make_tensor(tSsQBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDim/32>(tSsQBLayout.layout()));
Tensor sdOtemp = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdO{});
Tensor tdPsdOBLayout = smem_thr_copy_BLayout.partition_S(sdOtemp);
Tensor tdPsdO = make_tensor(tdPsdOBLayout.data(), convert_layout_B_rowcol<_64x32, kHeadDimV/32>(tdPsdOBLayout.layout()));
auto smem_tiled_copy_QdOt = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16_WITH_8x64, Element>{}, tiled_mma_dkv);
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
Tensor tdVsdOt8x64 = smem_thr_copy_QdOt.partition_S(sdOt);
Tensor tdVsdOt = make_tensor(tdVsdOt8x64.data(), convert_layout_B_rowcol<_16x64_64>(tdVsdOt8x64.layout()));
Tensor tdKsQt8x64 = smem_thr_copy_QdOt.partition_S(sQt);
Tensor tdKsQt = make_tensor(tdKsQt8x64.data(), convert_layout_B_rowcol<_16x64_64>(tdKsQt8x64.layout()));
Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK)));
Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV)));
Tensor tKcK = gmem_thr_copy_KV.partition_D(cK);
Tensor tVcV = gmem_thr_copy_KV.partition_D(cV);
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tSgK)));
Tensor tVpV = make_tensor<bool>(make_shape(size<2>(tdPgV)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.d; }
}
int m_block = m_block_max - 1;
int m_block_min = (!Is_causal && !Is_local)
? 0
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
// ============ Attention Mask partition for dK/dV ============
Tensor tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block_max > 0 ? m_block_max - 1 : 0));
Tensor tSrM = make_fragment_like<uint8_t>(tSgM);
clear(tSrM);
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdK = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));
Tensor cdV = make_identity_tensor(make_shape(size<0>(gdV), size<1>(gdV)));
Tensor tdKcdK = gmem_thr_copy_dKV.partition_D(cdK);
Tensor tdVcdV = gmem_thr_copy_dKV.partition_D(cdV);
Tensor tdKpdK = make_tensor<bool>(make_shape(size<2>(tdKcdK)));
Tensor tdVpdV = make_tensor<bool>(make_shape(size<2>(tdVcdV)));
#pragma unroll
for (int k = 0; k < size(tdKpdK); ++k) { tdKpdK(k) = get<1>(tdKcdK(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tdVpdV); ++k) { tdVpdV(k) = get<1>(tdVcdV(0, 0, k)) < params.d; }
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_KV, tSgK, tSrK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_KV, tdPgV, tdPrV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN
);
Tensor caccS = make_identity_tensor(Shape<Int<kBlockN>, Int<kBlockM>>{});
Tensor taccScS = thr_mma_sdp.partition_C(caccS);
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDimV>>{});
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
clear(acc_dv);
clear(acc_dk);
Tensor taccScS_row = taccScS(_, 0, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
}
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 0, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 1, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 2, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 3, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
// ============ Pre-read mask for first iteration ============
if constexpr (Use_mask) {
if (m_block >= m_block_min) {
tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block));
cute::copy(tSgM, tSrM);
}
}
#pragma unroll
for (; m_block >= m_block_min; m_block--) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
clear(acc_s_ori);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
asm volatile("s_barrier");
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
Tensor scores_trans = make_tensor(acc_s.data(), flash::convert_trans_layout_acc_rowcol(acc_s.layout()));
if constexpr (Is_softcap) {
flash::apply_softcap(acc_s, params.softcap);
}
[[maybe_unused]] Tensor dtanh_trans = make_tensor_like(scores_trans);
if constexpr (Is_softcap) {
flash::calculate_dtanh(scores_trans, dtanh_trans, params.softcap);
}
// ============ Apply attention mask (transposed for dK/dV) ============
flash::apply_atten_mask<Use_mask>(tSrM, acc_s_ori, params.masked_value);
#if 1
if constexpr (Has_alibi) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = tidx / 64;
const int col_idx_offset = m_block * kBlockM;
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
alibi.apply_alibi_trans(scores, col_idx_offset, row_idx_offset_, kNWarps * 16);
}
#endif
#if 1
if constexpr(!Is_causal && !Is_local) {
if (!Is_even_MN && (m_block + 1) * kBlockM >= binfo.actual_seqlen_q) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int col_idx_offset_ = m_block * kBlockM;
flash::apply_mask_trans(scores, binfo.actual_seqlen_q, col_idx_offset_);
}
} else if constexpr(Is_causal) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
flash::apply_mask_causal_trans(
scores,
m_block * kBlockM,
binfo.actual_seqlen_k,
row_idx_offset_,
binfo.actual_seqlen_q,
kNWarps * 16
);
}
} else if constexpr(Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
flash::apply_mask_local_trans(
scores,
m_block * kBlockM,
binfo.actual_seqlen_k,
row_idx_offset_,
binfo.actual_seqlen_q,
kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
}
#endif
flash::scale_apply_exp2</*scale_max=*/false>(scores_trans, lse, params.scale_softmax_log2);
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
dP_sum(mi) = gdPsum(row);
}
if (m_block > m_block_min) {
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
gLSE.data() = gLSE.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
lse(mi) = gLSE(row);
}
// ============ Pre-read mask for next iteration ============
if constexpr (Use_mask) {
tSgM = thr_mma_sdp.partition_C(gM(_, _, m_block - 1));
cute::copy(tSgM, tSrM);
}
}
if constexpr (Is_dropout) {
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
int block_row_idx = row_idx_offset_;
int block_col_idx = m_block * kBlockM;
dropout.template apply_dropout_trans</*encode_dropout_in_sign_bit=*/true>(
acc_s, block_row_idx, block_col_idx, kNWarps * 16
);
}
Tensor rP = !Is_dropout
? flash::convert_type<Element>(acc_s)
: flash::convert_type_relu<Element>(acc_s);
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_barrier");
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
clear(acc_dp_ori);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
asm volatile("s_barrier");
Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
Tensor dS = make_tensor(acc_dp.data(), scores_trans.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
float scaled_ds = pointwise_mult(scores_trans(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh_trans(mi, ni); }
dS(mi, ni) = scaled_ds;
}
}
Tensor tdKrdSt = flash::convert_type<Element>(acc_dp);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_barrier");
if (m_block > m_block_min) {
gQ.data() = gQ.data() + (-int(kBlockM * params.q_row_stride));
gdO.data() = gdO.data() + (-int(kBlockM * params.do_row_stride));
lds_direct_copy<Is_even_K, true>(gQ, sQ, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true>(gQ, sQ, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
}
}
// Epilogue: write dK and dV
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.dv_row_stride, _1{}));
int row, col;
if constexpr (size<1>(acc_dk) == size<1>(acc_dv) && size<2>(acc_dk) == size<2>(acc_dv)) {
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
if (Is_even_K || col < params.d) {
gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout );
}
col += 4;
}
}
}
}
} else {
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
if (Is_even_K || col < params.d) {
gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
}
col += 4;
}
}
}
}
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dv); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dv); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dv); ++ei) {
if (Is_even_K || col < params.d) {
gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout);
}
col += 4;
}
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// dK/dV wrapper with attention mask support (dispatches by kHeadDim)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side entry point for the dK/dV backward pass with an explicit attention mask.
//
// Block-to-work mapping (read from the launch grid):
//   blockIdx.x -> key/value column block (n_block)
//   blockIdx.y -> attention head (bidh)
//   blockIdx.z -> batch element (bidb)
//
// Dispatches on Kernel_traits::kHeadDim:
//   * 128: runs the 1-column-block kernel for n_block. When Is_causal is set and
//     NO_CAUSAL_OPT is not defined, the same thread block then also processes the
//     mirrored column block (num_n_block - 1 - n_block); the host launcher in this
//     file halves the grid (rounded up) to match.
//   * 64: runs the dim64 1-column-block kernel for n_block only.
// Any other head dimension compiles to a no-op.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi,
         bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Use_mask, typename Params>
inline __device__ void compute_dk_dv_trans_16x64_prefetch_attnmask(const Params &params) {
    const int bidb = blockIdx.z;
    const int bidh = blockIdx.y;
    const int n_block = blockIdx.x;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    if constexpr (kHeadDim == 128) {
        compute_dk_dv_trans_1colblock_16x64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
            Is_even_MN, Is_even_K, Is_softcap, Use_mask>(params, bidb, bidh, n_block);
#ifndef NO_CAUSAL_OPT
        if constexpr (Is_causal) {
            const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
            const int mirror_n_block = num_n_block - n_block - 1;
            // Skip the mirror when it is the block we just processed (the middle block of an
            // odd block count). The previous guard compared against num_n_block, which is
            // always true, so that block was computed and stored twice.
            if (mirror_n_block != n_block) {
                compute_dk_dv_trans_1colblock_16x64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
                    Is_even_MN, Is_even_K, Is_softcap, Use_mask>(params, bidb, bidh, mirror_n_block);
            }
        }
#endif
    } else if constexpr (kHeadDim == 64) {
        compute_dk_dv_trans_1colblock_16x64_dim64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
            Is_even_MN, Is_even_K, Is_softcap, Use_mask>(params, bidb, bidh, n_block);
    }
}
} // namespace flash
/******************************************************************************
* Copyright (c) 2026, Attnmask extension.
* Launch template for backward pass with explicit mask support.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "flash_bwd_launch_template.h"
#include "flash_bwd_attnmask_kernel.h"
#include "flash_attnmask.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
// Backward pass with explicit mask support.
//
// The mask is applied when recomputing S = QK^T in the backward pass,
// before the softmax (scale_apply_exp2), to ensure P = 0 at masked positions.
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper macro that declares an attnmask backward kernel. It expands to
//   template<typename Kernel_traits, <extra template parameters...>>
//   __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params_attnmask params)
// so the kernel body is written as a braced block directly after the macro invocation.
// The parameter struct is passed to the kernel by value. KERNEL_PARAM_MODIFIER is defined
// outside this file section.
#define DEFINE_FLASH_BACKWARD_ATTNMASK_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params_attnmask params)
// dK/dV kernel with mask support.
// Global entry point that forwards to flash::compute_dk_dv_trans_16x64_prefetch_attnmask with
// Use_mask fixed to true. The boolean template parameters are passed through unchanged.
// When ARCH_SUPPORTS_FLASH is not defined the body is replaced by FLASH_UNSUPPORTED_ARCH.
DEFINE_FLASH_BACKWARD_ATTNMASK_KERNEL(flash_bwd_dk_dv_attnmask_kernel,
bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
// Causal and local (sliding-window) masking are mutually exclusive at compile time.
static_assert(!(Is_causal && Is_local));
flash::compute_dk_dv_trans_16x64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
Is_even_MN, Is_even_K, Is_softcap, /*Use_mask=*/true>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
// dQ kernel with mask support.
// Global entry point that forwards to flash::compute_dq_seqq_parallel_16x64_prefetch_attnmask
// with Use_mask fixed to true. The boolean template parameters are passed through unchanged.
// When ARCH_SUPPORTS_FLASH is not defined the body is replaced by FLASH_UNSUPPORTED_ARCH.
DEFINE_FLASH_BACKWARD_ATTNMASK_KERNEL(flash_bwd_dq_attnmask_kernel,
bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
// Causal and local (sliding-window) masking are mutually exclusive at compile time.
static_assert(!(Is_causal && Is_local));
flash::compute_dq_seqq_parallel_16x64_prefetch_attnmask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi,
Is_even_MN, Is_even_K, Is_softcap, /*Use_mask=*/true>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launch function for backward pass with mask support (prefetch version for gfx936/gfx938)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Host launcher for the masked backward pass (prefetch kernels).
//
// Enqueues three kernels on `stream`, in this order:
//   1. flash_bwd_dot_do_o_kernel        (preprocess: dO * O row sums)
//   2. flash_bwd_dk_dv_attnmask_kernel  (instantiated with Kernel_trans_traits)
//   3. flash_bwd_dq_attnmask_kernel     (instantiated with Kernel_traits)
//
// Template parameters:
//   Kernel_traits        traits used for the preprocess and dQ kernels
//   Kernel_trans_traits  traits used for the dK/dV kernel
//   Is_dropout/Is_causal compile-time feature flags forwarded to the kernels
//
// The whole body compiles away when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_attnmask_prefetch(Flash_bwd_params_attnmask &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    // Number of row (M) and column (N) tiles covering the full problem.
    // NOTE(review): the N tile count uses Kernel_traits::kBlockN although the dK/dV kernel is
    // instantiated with Kernel_trans_traits; the two differ for hdim64 — confirm the dK/dV
    // kernel guards against out-of-range n_block.
    const int full_m_blocks = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    const int full_n_blocks = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
#ifdef NO_CAUSAL_OPT
    const int dq_blocks = full_m_blocks;
    const int dkdv_blocks = full_n_blocks;
#else
    // With the causal optimisation the grids are halved (rounded up). For dK/dV this only
    // applies when the head dimension is not 64, matching the mirrored-block handling in the
    // device-side dK/dV wrapper.
    const int dq_blocks = Is_causal ? ((full_m_blocks + 1) >> 1) : full_m_blocks;
    const int dkdv_blocks = (Is_causal && Kernel_trans_traits::kHeadDim != 64)
        ? ((full_n_blocks + 1) >> 1)
        : full_n_blocks;
#endif

    dim3 grid_dq(dq_blocks, params.h, params.b);
    // The preprocess kernel always covers every row tile and uses (block, batch, head) order.
    dim3 grid_dot(full_m_blocks, params.b, params.h);
    dim3 grid_dkdv(dkdv_blocks, params.h, params.b);

    // Preprocess: compute dO * O sum
    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_dot, Kernel_traits::kNThreads, 0, stream>>>(
        static_cast<Flash_bwd_params&>(params));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Runtime predicates that get lifted to compile-time constants by the *_SWITCH macros.
    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool even_mn_dq = fixed_length
        && params.seqlen_q % Kernel_traits::kBlockM == 0
        && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool even_mn_dkdv = fixed_length
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool even_k = params.d == Kernel_traits::kHeadDim;
    const bool local_window = (params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal;
    const bool with_alibi = params.alibi_slopes_ptr != nullptr;
    const bool with_softcap = params.softcap > 0.0;

    // Dynamic shared-memory sizes. For head dim 128 the dK/dV kernel gets an extra
    // kBlockM * kBlockN bytes on top of its prefetch buffers.
    constexpr int smem_dq = Kernel_traits::kSmemPrefetchSize;
    constexpr int smem_dkdv_prefetch = Kernel_trans_traits::kSmemPrefetchSize;
    constexpr int smem_dkdv_extra = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
    constexpr int smem_dkdv = (Kernel_trans_traits::kHeadDim == 128)
        ? (smem_dkdv_prefetch + smem_dkdv_extra)
        : smem_dkdv_prefetch;

    BOOL_SWITCH(even_mn_dq, IsEvenMNConst, [&] {
        EVENK_SWITCH(even_k, IsEvenKConst, [&] {
            LOCAL_SWITCH(local_window, Is_local, [&] {
                ALIBI_SWITCH(with_alibi, Has_alibi, [&] {
                    SOFTCAP_SWITCH(with_softcap, Is_softcap, [&] {
                        // dK/dV kernel: its even-MN flag is derived from the trans traits tile sizes.
                        BOOL_SWITCH(even_mn_dkdv, IsEvenMNTransConst, [&] {
                            auto dkdv_fn = &flash_bwd_dk_dv_attnmask_kernel<
                                Kernel_trans_traits,
                                Is_dropout && !Is_softcap,
                                Is_causal,
                                Is_local && !Is_causal,
                                Has_alibi,
                                IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
                                IsEvenKConst,
                                Is_softcap>;
                            dkdv_fn<<<grid_dkdv, Kernel_traits::kNThreads, smem_dkdv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                        // dQ kernel: launched after dK/dV on the same stream.
                        auto dq_fn = &flash_bwd_dq_attnmask_kernel<
                            Kernel_traits,
                            Is_dropout && !Is_softcap,
                            Is_causal,
                            Is_local && !Is_causal,
                            Has_alibi,
                            IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
                            IsEvenKConst,
                            Is_softcap>;
                        dq_fn<<<grid_dq, Kernel_traits::kNThreads, smem_dq, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launch function for backward pass with mask support (seqk parallel version)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Host launcher for the masked backward pass (seqk-parallel variant).
//
// Enqueues, in order on `stream`: the dO * O preprocess kernel, the dK/dV attnmask kernel
// (Kernel_trans_traits) and the dQ attnmask kernel (Kernel_traits). Unlike the prefetch
// launcher, the grids here always cover every tile, and the shared-memory sizes come from
// kSmemSizeTrans1colblock / kSmemSize1rowblock.
//
// The whole body compiles away when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_attnmask_seqk_parallel_trans(Flash_bwd_params_attnmask &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    // Tile counts along the query (M) and key (N) dimensions, both from Kernel_traits.
    const int m_blocks = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    const int n_blocks = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;

    dim3 grid_dq(m_blocks, params.h, params.b);
    // The preprocess kernel uses (block, batch, head) order.
    dim3 grid_dot(m_blocks, params.b, params.h);
    dim3 grid_dkdv(n_blocks, params.h, params.b);

    // Preprocess: compute dO * O sum
    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_dot, Kernel_traits::kNThreads, 0, stream>>>(
        static_cast<Flash_bwd_params&>(params));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Runtime predicates that get lifted to compile-time constants by the *_SWITCH macros.
    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool even_mn_dq = fixed_length
        && params.seqlen_q % Kernel_traits::kBlockM == 0
        && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool even_mn_dkdv = fixed_length
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool even_k = params.d == Kernel_traits::kHeadDim;
    const bool local_window = (params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal;
    const bool with_alibi = params.alibi_slopes_ptr != nullptr;
    const bool with_softcap = params.softcap > 0.0;

    constexpr int smem_dkdv = Kernel_trans_traits::kSmemSizeTrans1colblock;
    constexpr int smem_dq = Kernel_traits::kSmemSize1rowblock;

    BOOL_SWITCH(even_mn_dkdv, IsEvenMNTransConst, [&] {
        BOOL_SWITCH(even_mn_dq, IsEvenMNConst, [&] {
            EVENK_SWITCH(even_k, IsEvenKConst, [&] {
                LOCAL_SWITCH(local_window, Is_local, [&] {
                    ALIBI_SWITCH(with_alibi, Has_alibi, [&] {
                        SOFTCAP_SWITCH(with_softcap, Is_softcap, [&] {
                            // dK/dV kernel
                            auto dkdv_fn = &flash_bwd_dk_dv_attnmask_kernel<
                                Kernel_trans_traits,
                                Is_dropout && !Is_softcap,
                                Is_causal,
                                Is_local && !Is_causal,
                                Has_alibi,
                                IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
                                IsEvenKConst,
                                Is_softcap>;
                            dkdv_fn<<<grid_dkdv, Kernel_traits::kNThreads, smem_dkdv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();

                            // dQ kernel, launched after dK/dV on the same stream
                            auto dq_fn = &flash_bwd_dq_attnmask_kernel<
                                Kernel_traits,
                                Is_dropout && !Is_softcap,
                                Is_causal,
                                Is_local && !Is_causal,
                                Has_alibi,
                                IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
                                IsEvenKConst,
                                Is_softcap>;
                            dq_fn<<<grid_dq, Kernel_traits::kNThreads, smem_dq, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Backward dispatch for attnmask - uses the new kernels with mask support
////////////////////////////////////////////////////////////////////////////////////////////////////
// Masked backward dispatch for head dimension 128.
//
// Picks the dK/dV ("trans") and dQ prefetch kernel traits and forwards to
// run_flash_bwd_attnmask_prefetch. Dropout is enabled at compile time when
// params.p_dropout < 1. The dQ row tile is 64 only when both dropout and causal
// masking are active, otherwise 128.
template<typename T, bool Is_causal>
void run_mha_bwd_attnmask_hdim128(Flash_bwd_params_attnmask &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr static int Headdim = 128;
    // Attnmask backward kernel requires prefetch traits (uses kStages, SmemLayoutKGemm0, etc.)
    // Unified to use prefetch traits for all devices
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        constexpr int kDqBlockM = (Is_dropout && Is_causal) ? 64 : 128;
        using dkdv_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<
            Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
        using dq_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<
            Headdim, /*kBlockM_*/kDqBlockM, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
        run_flash_bwd_attnmask_prefetch<dq_traits, dkdv_traits, Is_dropout, Is_causal>(params, stream);
    });
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Entry point for backward pass with mask (run_mha_bwd_attnmask_, defined after the hdim64 dispatcher below), dispatched by element type, head dimension and causal flag.
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Backward dispatch for attnmask - hdim64
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launches the attnmask backward pass for head dimension 64.
//
// Template parameters:
//   T         - element type forwarded to the kernel traits.
//   Is_causal - forwarded to the launcher.
// Arguments:
//   params - backward parameters (with mask), passed through to the launcher.
//   stream - CUDA stream passed through to the launcher.
//
// The whole body compiles away when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename T, bool Is_causal>
void run_mha_bwd_attnmask_hdim64(Flash_bwd_params_attnmask &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr static int Headdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Tile shapes: the "trans" traits use 64x128, the dQ traits use 128x64.
        constexpr int kTransBlockM = 64;
        constexpr int kTransBlockN = 128;
        constexpr int kDqBlockM = 128;
        constexpr int kDqBlockN = 64;
        using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits<
            Headdim, kTransBlockM, kTransBlockN, /*kNWarps_*/4, T, 3>;
        using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits<
            Headdim, kDqBlockM, kDqBlockN, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
        run_flash_bwd_attnmask_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
    });
#endif
}
// Compile-time dispatcher for the attnmask backward pass: selects the
// per-head-dimension launcher from the Headdim template argument.
// Only head dimensions 64 and 128 are accepted; any other value fails to compile.
template<typename T, int Headdim, bool Is_causal>
void run_mha_bwd_attnmask_(Flash_bwd_params_attnmask &params, cudaStream_t stream) {
    static_assert(Headdim == 64 || Headdim == 128);
    if constexpr (Headdim == 64) {
        run_mha_bwd_attnmask_hdim64<T, Is_causal>(params, stream);
    } else {
        // The static_assert above leaves 128 as the only remaining value.
        run_mha_bwd_attnmask_hdim128<T, Is_causal>(params, stream);
    }
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization: bfloat16 elements, head dimension 128, causal.
// Forwards directly to the hdim128 backward launcher.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kIsCausal = true;
    run_mha_bwd_hdim128<Element, kIsCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization: bfloat16 elements, head dimension 128, non-causal.
// Forwards directly to the hdim128 backward launcher.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kIsCausal = false;
    run_mha_bwd_hdim128<Element, kIsCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization: half-precision elements, head dimension 128, causal.
// Forwards directly to the hdim128 backward launcher.
template<>
void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kIsCausal = true;
    run_mha_bwd_hdim128<Element, kIsCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization: half-precision elements, head dimension 128, non-causal.
// Forwards directly to the hdim128 backward launcher.
template<>
void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kIsCausal = false;
    run_mha_bwd_hdim128<Element, kIsCausal>(params, stream);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment