Commit d380e87f authored by Tri Dao

Don't use Smem_dp_sum in backward pass

To reduce smem usage for SM75
parent b17c6fe2
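
Background for the change: smem_size_dq_dk_dv in the first hunk below is the amount of dynamic shared memory requested via cudaFuncAttributeMaxDynamicSharedMemorySize, and Turing (SM75) only allows about 64 KB of opt-in dynamic shared memory per block, far less than A100 (SM80). Dropping the Smem_dp_sum tile from that budget is what lets the backward kernel fit on SM75. A rough back-of-the-envelope sketch of the arithmetic, with made-up placeholder tile sizes (only smem_size_dp_sum = 16 * 4 * 2 is taken from the removed static_assert; the other values are not the real Kernel_traits numbers):

#include <cstdio>

int main() {
    // Placeholder byte counts standing in for Kernel_traits::*::BYTES_PER_TILE;
    // the real values depend on the tile shapes chosen in FMHA_kernel_traits.
    constexpr int smem_size_q      = 16 * 1024;
    constexpr int smem_size_v      = 16 * 1024;
    constexpr int smem_size_dq     = 8 * 1024;
    constexpr int smem_size_s      = 4 * 1024;
    constexpr int smem_size_dp_sum = 16 * 4 * 2;   // matches the removed static_assert

    // Same budget formula as run_fmha_dgrad_fp16_sm80_loop_ (here assuming
    // V_IN_REGS, so V is staged in shared memory once): double-buffered Q,
    // V, dQ, and two S tiles, plus (before this commit) the dp_sum tile.
    constexpr int with_dp_sum    = smem_size_q * 2 + smem_size_v + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
    constexpr int without_dp_sum = smem_size_q * 2 + smem_size_v + smem_size_dq + smem_size_s * 2;

    // SM75 caps opt-in dynamic shared memory at roughly 64 KB per block
    // (A100 allows far more), so even a 128-byte tile can push the kernel over the edge.
    constexpr int kSm75MaxDynamicSmem = 64 * 1024;
    static_assert(without_dp_sum <= kSm75MaxDynamicSmem, "backward smem budget must fit on SM75");

    std::printf("with dp_sum: %d B, without: %d B, SM75 limit: %d B\n",
                with_dp_sum, without_dp_sum, kSm75MaxDynamicSmem);
    return 0;
}
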
@@ -15,15 +15,13 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
     constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;

     using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
     constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
     static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
-    static_assert(smem_size_dp_sum == 16 * 4 * 2);

-    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
+    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;

     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     bool is_causal = params.is_causal;
@@ -41,6 +39,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
         : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
     }

+    // printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -97,4 +96,28 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     }
+    // if (params.d == 64) {
+    //     auto dprops = at::cuda::getCurrentDeviceProperties();
+    //     if (dprops->major == 7 && dprops->minor == 5) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //     } else {
+    //         if( params.s == 128 ) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //         } else if( params.s >= 256 ) {
+    //             if (dprops->major == 8 && dprops->minor == 0) {
+    //                 // Don't share smem for K & V, and don't keep V in registers
+    //                 // This speeds things up by 2-3% by avoiding register spills, but it
+    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+    //                 // For other GPUs, we keep V in registers.
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             } else if (dprops->major == 8 && dprops->minor > 0) {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             }
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file
@@ -12,16 +12,19 @@ namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Smem_dp_sum, int M>
-inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
-                                Smem_dp_sum smem, const int buffer_idx) {
+template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
+inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+                                Gmem_softmax_sum gmem_softmax_d, int tidx) {
+    float sum[M];
+    fmha::SumOp<float> sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
+    }
+    const int dp_sum_row = tidx / THREADS_PER_ROW;
+    if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
+        gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
     }
-    static_assert(M == 1);
-    smem.store(sum[0], buffer_idx);
-    // smem.store(sum, buffer_idx);
 }

////////////////////////////////////////////////////////////////////////////////////////////////////
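
The rewritten dot_do_o above computes dp_sum[row] = sum over d of dO[row,d] * O[row,d] entirely in registers: the THREADS_PER_ROW threads of a row reduce with fmha::Allreduce (presumably a shuffle-based reduction) and the first thread of the row writes the result straight to gmem_softmax_d, so no Smem_dp_sum tile is needed. A minimal standalone CUDA sketch of that pattern follows; it is not the repo's actual implementation — the kernel name, launch shape, and plain fp16 loads (instead of hmulsum8 on uint4 fetches) are simplifications, and it assumes THREADS_PER_ROW is a power of two no larger than a warp and blockDim.x == ROWS * THREADS_PER_ROW.

#include <cuda_fp16.h>

// Hypothetical simplified kernel: each block handles ROWS rows of length D,
// with THREADS_PER_ROW consecutive threads cooperating on one row.
template<int ROWS, int THREADS_PER_ROW, int D>
__global__ void dot_do_o_sketch(const __half *dO, const __half *O, float *dp_sum) {
    const int tidx = threadIdx.x;
    const int row  = blockIdx.x * ROWS + tidx / THREADS_PER_ROW;
    const int lane = tidx % THREADS_PER_ROW;

    // Each thread accumulates its strided slice of the row dot product dO[row,:] . O[row,:].
    float sum = 0.f;
    for (int d = lane; d < D; d += THREADS_PER_ROW) {
        sum += __half2float(dO[row * D + d]) * __half2float(O[row * D + d]);
    }

    // Butterfly allreduce across the THREADS_PER_ROW lanes of the row.
    // Works because a row's lanes form an aligned power-of-two group inside one warp.
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        sum += __shfl_xor_sync(0xffffffffu, sum, offset);
    }

    // One lane per row writes the reduced value straight to global memory,
    // playing the role of gmem_softmax_d.store_row() in the real code.
    if (lane == 0) { dp_sum[row] = sum; }
}

// Example launch (assuming num_rows is a multiple of ROWS):
//   dot_do_o_sketch<16, 8, 64><<<num_rows / 16, 16 * 8>>>(dO, O, dp_sum);
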
@@ -101,8 +104,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng

     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
-    using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
-
     // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
     using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
@@ -208,26 +209,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
         gmem_softmax_lse.move();
-        float dp_sum[Mma_tile_p::MMAS_M * 2];
-        if (!Is_first) {
-            gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-            gmem_softmax_d.move();
-        }
-
-        float dp_sum_regs[Gmem_tile_do::LDGS];
-        Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
-
         if (!Is_first) { __syncthreads(); }

         // Commit the data for Q, dO, and V to shared memory.
         gmem_q.commit(gemm_q_k.smem_q);
         gmem_do.commit(smem_do);
         if (Is_first) {
-            dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
-            const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
-            if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-                gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
-            }
-            gmem_softmax_d.move();
+            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+                gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+            );
         }

         // Instead of scaling dP by rp_dropout, we scale V instead
@@ -266,6 +255,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             }
         }

+        float dp_sum[Mma_tile_p::MMAS_M * 2];
+        gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+        gmem_softmax_d.move();
+
         // Commit the data for V to shared memory if it has not been done already.
         if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
             // Make sure we are done loading the fragments for K.
@@ -357,21 +350,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // __syncthreads();
         // }

-        // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
-        // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
-        // later. This is because loading dp_sum earlier uses more registers.
         fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
-        if (Is_first) {
-            fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
-        } else {
-            #pragma unroll
-            for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
-                #pragma unroll
-                for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
-                    #pragma unroll
-                    for (int ii = 0; ii < 8; ++ii) {
-                        acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
-                    }
-                }
-            }
-        }
+        #pragma unroll
+        for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+            #pragma unroll
+            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+                #pragma unroll
+                for (int ii = 0; ii < 8; ++ii) {
+                    acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
+                }
+            }
+        }
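
With dp_sum now always loaded from gmem_softmax_d before this point, the Is_first special case above goes away: acc_dp is seeded with -dp_sum so the subtraction in the standard softmax-gradient identity dS = P * (dP - dp_sum), where dP = dO * V^T, is folded into the accumulator init instead of being applied afterwards via softmax.subtract_dp_sum. A scalar model of why the two orderings agree (function, array, and variable names here are illustrative only):

// Old ordering (Is_first path): accumulate dO.V from zero, subtract dp_sum afterwards.
float dp_via_subtract(const float *dout, const float *v, int K, float p, float dp_sum) {
    float acc = 0.f;
    for (int k = 0; k < K; ++k) { acc += dout[k] * v[k]; }
    return p * (acc - dp_sum);
}

// New ordering: seed the accumulator with -dp_sum; no separate subtract step.
float dp_via_init(const float *dout, const float *v, int K, float p, float dp_sum) {
    float acc = -dp_sum;
    for (int k = 0; k < K; ++k) { acc += dout[k] * v[k]; }
    return p * acc;   // identical result, up to floating-point rounding order
}
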
@@ -409,12 +395,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
         smem_kt.load(frag_kt[0], 0);

-        if (Is_first) {
-            const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
-            const int row[2] = {quad, quad + 8};
-            smem_dp_sum.load(dp_sum, row, l % 2);
-        }
-
         // Trigger the load for the next dO values.
         if( l < steps - 1) {
             smem_do.move_to_next_write_buffer();
@@ -430,7 +410,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
         // // will be zero.
         // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-        if (Is_first) { softmax.subtract_dp_sum(dp_sum); }

         Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
         softmax.pack(frag_dp);
@@ -547,21 +526,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         if(l < steps - 1) {
             gmem_do.commit(smem_do);
             if (Is_first) {
-                // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
-                // smem_dp_sum.move_to_next_write_buffer();
-                dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
-                const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
-                if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-                    gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
-                }
-                gmem_softmax_d.move();
+                dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+                    gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+                );
             }
             gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
             gmem_softmax_lse.move();
-            if (!Is_first) {
-                gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-                gmem_softmax_d.move();
-            }
         }

         typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -591,6 +561,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Make sure dQ is in shared memory.
         __syncthreads();

+        if (l < steps - 1) {
+            gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+            gmem_softmax_d.move();
+        }
+
         // Load from shared memory.
         smem_dq.template load</*zero_init=*/Is_first>(dq_out);
@@ -120,10 +120,25 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
     //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    // }
+    // if (launch_params.params.d == 64) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     } else if( launch_params.params.s >= 256 ) {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0) {
+    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         } else if (dprops->major == 7 && dprops->minor == 5) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         }
+    //     }
     // }
 }
\ No newline at end of file