Commit 1aa6d7d9 authored by Tri Dao

Rework dropout to decouple forward and backward

The forward and backward passes no longer need to use the same block size, number of threads, etc.
parent 1d0b41be
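Note on the scheme: the commit keys the Philox RNG so that the dropout pattern depends only on (batch, head, lane) and on which 16 x 16 tile of the attention matrix is being processed, never on launch geometry. A minimal sketch of the idea, assuming the Philox class modified below; dropout_sketch, row_block, col_block, and num_col_blocks are illustrative names, not from this commit:

// Sketch only: how fwd and bwd arrive at the same dropout bits for a tile.
__device__ void dropout_sketch(unsigned long long seed, unsigned long long offset,
                               int bidb, int bidh, int num_heads, int tidx,
                               int row_block, int col_block, int num_col_blocks) {
    // Offset encodes (batch, head, lane): identical in fwd and bwd,
    // independent of block size and thread count.
    Philox ph(seed, 0, offset + (bidb * num_heads + bidh) * 32 + tidx % 32);
    // Subsequence encodes the position of one 16 x 16 tile of the attention
    // matrix, numbered row-major; any kernel that touches this tile, in any
    // order, draws the same 128 bits (eight 16-bit comparisons).
    unsigned long long subsequence =
        (unsigned long long)row_block * num_col_blocks + col_block;
    uint4 bits = ph(subsequence);
    (void)bits;  // compared against p_dropout_in_uint16_t in apply_dropout_16bits
}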
@@ -236,7 +236,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

-    int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
+    int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
...
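A quick arithmetic check of the rounding above, with illustrative values:

// max_seqlen_k_ = 300, blocksize_c = 256:
//   max_seqlen_k = ((300 + 256 - 1) / 256) * 256 = (555 / 256) * 256 = 2 * 256 = 512
// Integer division truncates, so this rounds 300 up to the next multiple of 256.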
@@ -63,7 +63,8 @@ struct Mask {
         const bool col_valid = current_col < actual_seqlen_k;
         // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
             //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
-        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        // bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
         //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
         // }
         return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
...
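The causal branch of the return statement, with illustrative numbers:

// Cta_tile::N = 256, loop_step_idx = 1 -> this iteration covers global columns 256..511.
// For current_col = 10 (global column 10 + 1 * 256 = 266):
//   current_row = 300: 266 <= 300 -> kept (at or below the causal diagonal)
//   current_row = 200: 266 >  200 -> masked out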
@@ -1646,6 +1646,19 @@ struct Smem_tile_dp_sum {
         }
     }

+    inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
+        float *smem_write = smem_;
+        // Extract the position in the warp.
+        int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
+        int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
+        int row = lane / 4;
+        #pragma unroll
+        for (int mi = 0; mi < MMAS_M; ++mi) {
+            smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
+            smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
+        }
+    }
+
     inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
         float *smem_write = smem_ + buffer_idx * ROWS;
         // Extract the position in the warp.
...
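A reading of the new overload above: within a 16-row MMA tile, a thread's quad row is row = lane / 4, and each thread owns two rows spaced 8 apart, hence the pair of stores:

// smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];  // rows 0..7 of the tile
// smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];  // rows 8..15 of the tile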
@@ -278,87 +278,55 @@ struct Softmax_base {
     // }

     template <bool encode_dropout_in_sign_bit=false>
-    inline __device__ void apply_dropout(Philox &ph, uint32_t p_dropout_in_uint) {
+    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
         // We encode the dropout pattern in the sign bit of the non-negative
         // softmax to distinguish from pre-existing zeros
         auto encode_dropout = [](bool keep, float val) {
             return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
         };
         #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
             #pragma unroll
             for( int ni = 0; ni < MMAS_N; ni++ ) {
-                uint4 tmp = ph();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * ni + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
-                elt_[mi][4 * ni + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
-                elt_[mi][4 * ni + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
-                elt_[mi][4 * ni + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
-            }
-        }
-    }
-
-    template <bool encode_dropout_in_sign_bit=false>
-    inline __device__ void apply_dropout(Philox &ph0, Philox &ph1, uint32_t p_dropout_in_uint) {
-        // We encode the dropout pattern in the sign bit of the non-negative
-        // softmax to distinguish from pre-existing zeros
-        auto encode_dropout = [](bool keep, float val) {
-            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
-        };
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
-            static_assert(MMAS_N % 2 == 0);
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ni += 2 ) {
-                uint4 tmp = ph0();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * ni + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
-                elt_[mi][4 * ni + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
-                elt_[mi][4 * ni + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
-                elt_[mi][4 * ni + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
-                tmp = ph1();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * (ni + 1) + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]);
-                elt_[mi][4 * (ni + 1) + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]);
-                elt_[mi][4 * (ni + 1) + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]);
-                elt_[mi][4 * (ni + 1) + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]);
+                uint16_t tmp[8];
+                // fmha::uint4_to_ushort8(ph(), tmp);
+                uint4 tmp_32 = ph();
+                fmha::uint4_to_ushort8(tmp_32, tmp);
+                // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
+                // }
+                #pragma unroll
+                for (int ii = 0; ii < 2; ++ii) {
+                    #pragma unroll
+                    for (int jj = 0; jj < 4; ++jj) {
+                        elt_[mi * 2 + ii][4 * ni + jj] =
+                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
+                    }
+                }
             }
         }
     }

     template <bool encode_dropout_in_sign_bit=false>
-    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
+    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
+                                                unsigned long long philox_subsequence) {
         // We encode the dropout pattern in the sign bit of the non-negative
         // softmax to distinguish from pre-existing zeros
         auto encode_dropout = [](bool keep, float val) {
             return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
         };
+        static_assert(MMAS_M == 1);  // We're assuming 16x16 blocks.
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
             #pragma unroll
             for( int ni = 0; ni < MMAS_N; ni++ ) {
                 uint16_t tmp[8];
-                fmha::uint4_to_ushort8(ph(), tmp);
+                // fmha::uint4_to_ushort8(ph(), tmp);
+                fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
+                // uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
+                // fmha::uint4_to_ushort8(tmp_32, tmp);
                 // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
+                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
                 // }
                 #pragma unroll
                 for (int ii = 0; ii < 2; ++ii) {
...
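For reference, a host-side sketch of where the 16-bit threshold is assumed to come from (this conversion is not part of the diff and the name p_dropout_to_uint16 is hypothetical): the keep probability scaled to the uint16_t range, so rand16 <= threshold keeps an element with probability about 1 - p_dropout.

#include <cmath>
#include <cstdint>

// Hypothetical helper, not in this diff: maps a drop probability to the
// threshold compared against each 16-bit Philox draw.
inline uint16_t p_dropout_to_uint16(float p_dropout) {
    return static_cast<uint16_t>(std::floor((1.0f - p_dropout) * 65535.0f));
}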
@@ -334,7 +334,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     if (Is_dropout) {
         // softmax.apply_dropout(ph, params.p_dropout_in_uint);
         // softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
-        softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
+        // softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
+        unsigned int warp_idx = threadIdx.x / 32;
+        // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
+        unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
+        unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
+        softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
     }

     using Frag_p = fmha::Fragment_a<fmha::Row>;
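A worked example of the subsequence computed above, with illustrative numbers:

// binfo.actual_seqlen_k = 512 -> 512 / 16 = 32 column tiles per tile-row.
// With begin + l = 3, Cta_tile_p::N = 128, loop_step_idx = 2, warp_idx = 1:
//   block_col_idx      = 2 * 128 / 16 + 1 = 17
//   philox_subsequence = 3 * 32 + 17      = 113
// A fwd kernel visiting the same 16 x 16 tile computes the same 113, even with
// a different block size or a different traversal order.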
@@ -676,9 +681,8 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;

-    const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
+    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);

     if (loop_steps == 1) {
         compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
...
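Reading off the Philox offset used above, with illustrative numbers:

// bidb = 2, params.h = 12, bidh = 5:
//   base = (2 * 12 + 5) * 32 = 928, so lanes 0..31 get offsets 928..959.
// Every warp of every kernel (fwd or bwd) for this (batch, head) shares the same
// 32 per-lane streams; the pattern no longer depends on blockDim or warp count.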
@@ -115,25 +115,15 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else if( launch_params.params.seqlen_k >= 256 ) {
-            if (dprops->major == 8 && dprops->minor >= 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                if (launch_params.is_dropout) { // Need to use the same block size as backward
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                } else {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                }
-            }
+            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 128) {
         if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else {
-            if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
+            if (dprops->major == 8 && dprops->minor == 0) {
                 // TD [2022-06-05] Keep K in registers to reduce register spilling
                 // Gives about 6% speedup compared to using block size 128.
                 using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
...
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -470,9 +470,17 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     constexpr bool encode_dropout_in_sign_bit = Return_softmax;
     if (Is_dropout) {
-        // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
-        // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
-        softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
+        // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
+        // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
+        // softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
+        unsigned int warp_idx = threadIdx.x / 32;
+        // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
+        unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
+        // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
+        // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
+        // to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
+        unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
+        softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
     }

     using Frag_p = fmha::Fragment_a<fmha::Row>;
@@ -650,23 +658,28 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;

-    const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
+    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
+    // them to have the same number of threads or have to traverse the attention matrix
+    // in the same order.
+    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
+    // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
+    // the attention matrix. This way, as long as we have the batch, head, and the location of
+    // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
-    Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
+    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;

     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
     }
 }
...
+// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh
 // Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
 #pragma once
 // Philox CUDA.

@@ -9,8 +10,7 @@ public:
   __device__ inline Philox(unsigned long long seed,
                            unsigned long long subsequence,
                            unsigned long long offset)
-      : STATE(0)
-      , key(reinterpret_cast<const uint2&>(seed)) {
+      : key(reinterpret_cast<const uint2&>(seed)) {
     //key.x = (unsigned int)seed;
     //key.y = (unsigned int)(seed >> 32);
     //counter = make_uint4(0, 0, 0, 0);
@@ -19,7 +19,6 @@ public:
     //STATE = 0;
     //incr_n(offset / 4);

-    // key = reinterpret_cast<const uint2&>(seed);
     ull2 * tmp = reinterpret_cast<ull2*>(&counter);
     tmp->x = offset / 4;
     tmp->y = subsequence;
@@ -27,34 +26,46 @@ public:
     //     printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
     // }
   }

   __device__ inline uint4 operator()() {
-    // if (STATE == 0) {
     uint4 counter_ = counter;
     uint2 key_ = key;
     // 7-round philox
     #pragma unroll
     for (int i = 0; i < 6; i++) {
       counter_ = single_round(counter_, key_);
       key_.x += (kPhilox10A);
       key_.y += (kPhilox10B);
     }
-    // output = single_round(counter_, key_);
     uint4 output = single_round(counter_, key_);
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
     //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
     //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
     // }
     incr();
-    // return a float4 directly
-    // unsigned long ret;
-    // switch(STATE) {
-    //  case 0: ret = output.x; break;
-    //  case 1: ret = output.y; break;
-    //  case 2: ret = output.z; break;
-    //  case 3: ret = output.w; break;
-    //}
-    // STATE = (STATE + 1) % 4;
     return output;
   }
+
+  __device__ inline uint4 operator()(const unsigned long long subsequence) {
+    uint4 counter_ = counter;
+    ull2 * tmp = reinterpret_cast<ull2*>(&counter_);
+    tmp->y = subsequence;
+    // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+    //     printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w);
+    // }
+    uint2 key_ = key;
+    // 7-round philox
+    #pragma unroll
+    for (int i = 0; i < 6; i++) {
+      counter_ = single_round(counter_, key_);
+      key_.x += (kPhilox10A);
+      key_.y += (kPhilox10B);
+    }
+    uint4 output = single_round(counter_, key_);
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+    //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
+    //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
+    // }
+    return output;
+  }
@@ -64,25 +75,23 @@ private:
     uint64_t y;
   };
   uint4 counter;
-  // uint4 output;
   const uint2 key;
-  unsigned int STATE;

-  __device__ inline void incr_n(unsigned long long n) {
-    unsigned int nlo = (unsigned int)(n);
-    unsigned int nhi = (unsigned int)(n >> 32);
-    counter.x += nlo;
-    if (counter.x < nlo)
-      nhi++;
-    counter.y += nhi;
-    if (nhi <= counter.y)
-      return;
-    if (++counter.z)
-      return;
-    ++counter.w;
-  }
+  // __device__ inline void incr_n(unsigned long long n) {
+  //   unsigned int nlo = (unsigned int)(n);
+  //   unsigned int nhi = (unsigned int)(n >> 32);
+  //   counter.x += nlo;
+  //   if (counter.x < nlo)
+  //     nhi++;
+  //   counter.y += nhi;
+  //   if (nhi <= counter.y)
+  //     return;
+  //   if (++counter.z)
+  //     return;
+  //   ++counter.w;
+  // }

-  __device__ uint4 incr128 (uint4 ctr)
-  {
+  __device__ uint4 incr(uint4 ctr) {
     uint4 res;
     asm ("add.cc.u32 %0, %4, %8;\n\t"
          "addc.cc.u32 %1, %5, %9;\n\t"
@@ -98,42 +107,46 @@ private:
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
     //     printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
     // }
-    counter = incr128(counter);
+    counter = incr(counter);
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
     //     printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
     // }
   }

-  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
-                                    unsigned int *result_high) {
-    *result_high = __umulhi(a, b);
-    return a * b;
-  }
+  // __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
+  //                                  unsigned int *result_high) {
+  //   *result_high = __umulhi(a, b);
+  //   return a * b;
+  // }

-  __device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
-  {
+  __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     uint2 *res;
     unsigned long long tmp;
     asm ("mul.wide.u32 %0, %1, %2;\n\t"
          : "=l"(tmp)
          : "r"(a), "r"(b));
     res = (uint2*)(&tmp);
     return *res;
   }

   __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
     //unsigned int hi0;
     //unsigned int hi1;
     //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
     //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
     //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
-    uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
-    uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
+    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
+    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
     uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
     return ret;
   }

   static const unsigned long kPhilox10A = 0x9E3779B9;
   static const unsigned long kPhilox10B = 0xBB67AE85;
   static const unsigned long kPhiloxSA = 0xD2511F53;
   static const unsigned long kPhiloxSB = 0xCD9E8D57;
 };

 // Inverse of 2^32.
 constexpr float M_RAN_INVM32 = 2.3283064e-10f;
 __device__ __inline__ float4 uniform4(const uint4 x) {
...
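A usage sketch of the new stateless operator()(const unsigned long long) overload (the function example is illustrative, not from this commit): it works on a copy of the counter and never calls incr(), so tiles can be serviced in any order and draws repeated deterministically.

// Sketch: stateless draws keyed by subsequence (i.e., tile index).
__device__ void example(Philox &ph) {
    uint4 a = ph(7);  // bits for the tile with subsequence 7
    uint4 b = ph(3);  // bits for tile 3; call order doesn't matter
    uint4 c = ph(7);  // identical to a: same counter, same key
    (void)a; (void)b; (void)c;
}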
@@ -7,12 +7,10 @@ import flash_attn_cuda

 def _get_block_size(device, head_dim, is_dropout):
     assert head_dim in [16, 32, 64, 128]
-    if head_dim in [16, 32]:
+    if head_dim in [16, 32, 64]:
         return 256
-    elif head_dim == 64:
-        return 128 if (torch.cuda.get_device_capability(device) == (7, 5) and is_dropout) else 256
     elif head_dim == 128:
-        return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
+        return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128

 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
...
@@ -621,7 +621,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('seqlen', [512])
 @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
...