Commit d3e64409 authored by Tri Dao

Implement bwd for head dim 128

parent 0d854692
......@@ -24,7 +24,7 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
FlashAttention currently supports:
1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
2. fp16.
3. Head dimensions 16, 32, 64.
3. Head dimensions 16, 32, 64, 128 (bwd requires A100).
Our tentative roadmap:
1. [Jun 2022] Make package pip-installable.
......@@ -32,7 +32,7 @@ Our tentative roadmap:
3. [Jun 2022] Refactor to use Cutlass.
4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
5. [Jun 2022] Support bf16.
6. [Jul 2022] Support head dimension 128.
6. ~~[Jul 2022] Support head dimension 128~~[Done].
7. [Jul 2022] Support SM70 GPUs (V100).
8. [Aug 2022] Fuse rotary embedding.
9. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
......
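The README line above ties the new d=128 backward to A100. The actual gate is added inside `mha_bwd` later in this diff; as a self-contained sketch it amounts to the check below. The ATen call and `TORCH_CHECK` macro are the ones the extension already uses; the free-standing function name is purely illustrative.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

// Hypothetical standalone version of the capability gate added in mha_bwd below:
// in this commit the d=128 backward is only supported on SM80 (A100) devices.
void check_head_dim_128_bwd_supported(int head_size) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    if (head_size == 128) {
        TORCH_CHECK(is_sm80, "backward with head dimension 128 currently requires an SM80 GPU (e.g. A100)");
    }
}
```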
......@@ -144,9 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
// int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
// int base_N = 256;
int seq_len = 512;
if( max_seq_len <= 128 ) {
seq_len = 128;
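The updated `base_N` expression above decides the key/value tile width for the forward kernel. A host-side sketch of just that decision is shown below; the condition is copied verbatim from the diff, while the function name `fwd_base_N` is made up for illustration.

```cpp
// Sketch of the forward-pass tile-width choice: d=128 only gets the wide 256 tile
// on SM80 without dropout; SM75 with d=64 and dropout also drops to a 128 tile.
int fwd_base_N(int head_size, bool is_dropout, bool is_sm80, bool is_sm75) {
    return ((head_size == 128 && (is_dropout || !is_sm80)) ||
            (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
}
```

The backward path later in this diff makes the analogous choice but always uses the 128-wide tile for d=128.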
......@@ -162,18 +160,13 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
at::Tensor o_tmp;
if (loop) {
o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
if (loop) { o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }
auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
// auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
at::Tensor s;
if (return_softmax) {
s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
// s = torch::ones({ batch_size, num_heads, seq_len, seq_len }, opts) * 10000.0;
}
if (return_softmax) { s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); }
if( zero_tensors ) {
ctx.zero_();
......@@ -228,7 +221,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total x num_heads x head_size
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const float softmax_scale,
......@@ -239,6 +232,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
auto launch = &run_fmha_dgrad_fp16_sm80;
......@@ -269,8 +263,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
TORCH_CHECK(is_sm80);
}
// int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
int seq_len = 512;
if( max_seq_len <= 128 ) {
......@@ -282,18 +278,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
}
bool loop = seq_len > base_N;
// It's possible the softmax_lse_ from the fwd has a different length since base_N could be different.
auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, seq_len)}).contiguous();
auto dqkv = torch::empty_like(qkv);
auto opts = qkv.options();
// auto softmax_lse =
// torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
// softmax.zero_();
// torch::nn::init::ones_(softmax);
// torch::nn::init::ones_(dqkv);
at::Tensor dq_tmp;
if (loop) {
dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
if (loop) { dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }
if( zero_tensors ) {
dqkv.zero_();
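The `softmax_lse_` slice near the top of this hunk handles the case where the forward and backward passes rounded the sequence length differently (their `base_N` can differ). A minimal libtorch sketch of that slice-and-contiguous step follows; the standalone function name is illustrative only.

```cpp
#include <torch/torch.h>

// Keep only the first seq_len entries along the last (sequence) dimension of the
// (batch, heads, seq_len_fwd) logsumexp tensor, then make it contiguous so the
// backward kernel can assume a dense layout.
torch::Tensor trim_softmax_lse(const torch::Tensor &lse_full, int64_t seq_len) {
    using namespace torch::indexing;
    return lse_full.index({Slice(), Slice(), Slice(None, seq_len)}).contiguous();
}
```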
......@@ -324,7 +316,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// We're gonna reset the rng state in Python after this kernel, so the counter offset
// here doesn't matter at all. We just choose an arbitrary number;
// here doesn't matter at all. We just choose an arbitrary number.
int64_t counter_offset = 4;
if( is_dropout ) {
......
......@@ -847,6 +847,7 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
// uint32_t smem_read_og = this->smem_ + this->smem_read_offset_;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
......@@ -872,6 +873,9 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og);
// }
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
......@@ -885,6 +889,8 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
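The new `MMAS_N == 8` branch above extends the read-offset swizzle to eight N-direction MMAs, which is the d=128 case. Replaying the xor pattern on the host shows it steps the read offset through eight distinct `BYTES_PER_LDS`-sized slots and wraps back to the start. The program below is only a sanity-check sketch of that pattern (it assumes a base offset of 0 for illustration), not repository code.

```cpp
#include <cstdio>

// Replays the xor sequence from the new MMAS_N == 8 branch, with offsets expressed
// in units of BYTES_PER_LDS. Expected output: reads at 0, 2, 4, ..., 14 and a final
// offset of 0, i.e. the swizzle visits eight distinct slots and wraps cleanly.
int main() {
    int offset = 0;  // stand-in for smem_read_offset_ / BYTES_PER_LDS
    for (int ni = 0; ni < 8; ++ni) {
        printf("ni=%d read at offset %d\n", ni, offset);
        offset ^= (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
    }
    printf("offset after full pass: %d\n", offset);  // wraps back to 0
    return 0;
}
```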
......@@ -1100,8 +1106,8 @@ struct Smem_tile_o {
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
uint32_t smem_read = this->smem_read_ + imm;
// TD [2022-06-05] Ugly fix for d=128, maybe there's a better way.
if ((Cta_tile::N == 128) && (ii % 2 == 1)) {
// TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) {
smem_read ^= 8 * BYTES_PER_LDS;
}
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
......@@ -1232,16 +1238,17 @@ struct Smem_tile_mma {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
write_col ^= (write_row & 0x07) * 4;
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
// write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4;
write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x07)))) * 4;
}
// TODO [TD] Only works for, D=16, D=32 or D=64
write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07))) * 4;
// write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
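In the non-`WARPS_M == 1` branch above, the write column is now swizzled with a mask that tops out at 0x07 (128-byte and wider rows share it). Replaying that addressing for a handful of thread ids makes the pattern visible. The snippet below copies the bit masks verbatim, but `BYTES_PER_ROW = 128` is just one example value and the program is a throwaway illustration.

```cpp
#include <cstdio>

// Replays the updated write addressing from the non-(WARPS_M == 1) branch for a few
// thread ids, using the same masks as the diff.
int main() {
    const int BYTES_PER_ROW = 128;  // example value only
    const int mask = BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07);
    for (int tidx : {0, 3, 4, 32, 36, 63}) {
        int write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
        int write_col = (tidx & 0x03);
        write_col ^= (write_row & mask) * 4;  // xor swizzle over the row's low bits
        printf("tidx=%2d -> write_row=%2d write_col=%2d\n", tidx, write_row, write_col);
    }
    return 0;
}
```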
......@@ -1309,7 +1316,8 @@ struct Smem_tile_mma_transposed : public Base {
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f))));
read_col ^= (read_row & 0x07);
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
......@@ -1357,7 +1365,9 @@ struct Smem_tile_mma_epilogue : public Base {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256);
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07))));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
......@@ -1402,6 +1412,9 @@ struct Smem_tile_mma_epilogue : public Base {
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
fmha::sts(offset + 0 * BYTES_PER_ROW, x);
fmha::sts(offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
......
......@@ -120,4 +120,8 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
// }
// }
// }
// if (params.d == 128) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
}
\ No newline at end of file
......@@ -512,6 +512,17 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0]));
// printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y);
// float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1]));
// printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y);
// }
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
// printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
// }
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if(l < steps - 1) {
......@@ -577,7 +588,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// if (Is_dropout) {
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
// }
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
}
// Output the values.
gmem_dq.store(dq_out, 0);
// Move to the next part of the output.
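The fix above scales every `STGS_PER_LOOP` chunk of `dq` by `params.scale_bmm1f` rather than only the first, which presumably matters once d=128 produces more than one store per loop iteration. Assuming `fmha::fmul4` multiplies the four fp32 lanes packed into a `uint4` by a scalar (an inference from its use here, not a confirmed signature), a host-side stand-in would look like the sketch below; `Packed4` and `fmul4_sketch` are hypothetical names.

```cpp
#include <cstdint>
#include <cstring>

struct Packed4 { uint32_t lane[4]; };  // host stand-in for CUDA's uint4

// Hypothetical equivalent of fmha::fmul4: treat each 32-bit lane as an fp32 value
// and scale it by s, which is what the dq_out loop above relies on.
Packed4 fmul4_sketch(Packed4 a, float s) {
    for (int i = 0; i < 4; ++i) {
        float f;
        std::memcpy(&f, &a.lane[i], sizeof(f));
        f *= s;
        std::memcpy(&a.lane[i], &f, sizeof(f));
    }
    return a;
}
```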
......@@ -614,7 +627,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
// printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
// printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
// }
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
......
......@@ -126,7 +126,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
// TD [2022-06-05] Keep K in registers to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
......@@ -170,7 +170,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
// if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
// // TD [2022-06-05] Keep K in registers to reduce register spilling
// // Gives about 6% speedup compared to using block size 128.
// using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
......
......@@ -382,8 +382,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Apply the mask.
softmax.apply_mask(mask);
// softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
......@@ -408,7 +406,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
}
}
// __half2 p_max[Mma_tile_p::MMAS_M];
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
// if ((threadIdx.x == 0) && (l == 38)) {
......
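The surviving `reduce_max</*zero_init=*/Is_first>(p_max)` call above is the running-max step of the online softmax. The sketch below shows what the `zero_init` flag is understood to mean, inferred from the `Is_first` usage rather than taken from the repository's actual implementation; the function and parameter names are illustrative.

```cpp
#include <algorithm>

// Hedged sketch of reduce_max<zero_init>: on the first outer-loop iteration the
// running row max is initialized from the current block, on later iterations it is
// combined with the value carried over from previously processed key blocks.
template <bool zero_init, int M>
void reduce_max_sketch(float (&p_max)[M], const float (&block_max)[M]) {
    for (int mi = 0; mi < M; ++mi) {
        p_max[mi] = zero_init ? block_max[mi] : std::max(p_max[mi], block_max[mi]);
    }
}
```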