Commit 5b838a8b authored by Tri Dao

Apply dropout scaling to dQ and dK instead of to V (in bwd)

Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
parent a5559a0e
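For intuition, here is a toy CUDA sketch (mine, not code from this commit) of the two placements of the 1/(1-p) rescale: the old path multiplies V in fp16 before the matmul, so the rescaled values get rounded back to fp16; the new path accumulates unscaled in fp32 and applies the rescale once to the accumulator. The inputs are made up.

    #include <cstdio>
    #include <cuda_fp16.h>

    __global__ void compare_dropout_scaling() {
        const float rp_dropout = 1.f / 0.9f;   // 1 / (1 - p) with p = 0.1
        double ref = 0.0;                      // double-precision reference
        float acc_v_scaled = 0.f;              // old: V pre-scaled in fp16
        float acc_fp32 = 0.f;                  // new: scale the fp32 accumulator
        for (int i = 0; i < 256; ++i) {
            float d = 0.01f + 0.001f * i;      // stand-in for a dO element
            float v = 0.02f + 0.0007f * i;     // stand-in for a V element
            __half hv = __float2half(v);
            __half hv_scaled = __hmul(__float2half(rp_dropout), hv);  // rounds to fp16
            acc_v_scaled += d * __half2float(hv_scaled);
            acc_fp32 += d * __half2float(hv);
            ref += (double)d * (double)v * (double)rp_dropout;
        }
        acc_fp32 *= rp_dropout;                // single fp32 multiply at the end
        printf("scale V in fp16:   err = %.3e\n", fabs((double)acc_v_scaled - ref));
        printf("scale acc in fp32: err = %.3e\n", fabs((double)acc_fp32 - ref));
    }

    int main() {
        compare_dropout_scaling<<<1, 1>>>();
        cudaDeviceSynchronize();
        return 0;
    }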
@@ -107,6 +107,7 @@ void set_params_fprop(FMHA_fprop_params &params,
     params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
     params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
     params.rp_dropout = 1.f / params.p_dropout;
+    params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
     TORCH_CHECK(p_dropout < 1.f);
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
...
@@ -115,6 +115,7 @@ struct FMHA_fprop_params : public Qkv_params {
     // Scale factor of 1 / (1 - p_dropout).
     float rp_dropout;
+    float scale_bmm1_rp_dropout;
     // Scale factor of 1 / (1 - p_dropout), in half2.
     uint32_t scale_dropout;
...
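The new field fuses two fp32 constants so that the backward epilogue multiplies once instead of twice. A hedged host-side sketch of how the pieces relate (the struct is my stand-in, not the real FMHA_fprop_params; note that in this codebase params.p_dropout holds the keep probability 1 - p):

    // Stand-in struct showing the relationship between the scaling
    // constants touched by this commit.
    struct ParamsSketch {
        float p_dropout;              // keep probability, i.e. 1 - p
        float rp_dropout;             // 1 / (1 - p)
        float scale_bmm1f;            // softmax scale, e.g. 1/sqrt(head_dim)
        float scale_bmm1_rp_dropout;  // fused: rp_dropout * scale_bmm1f
    };

    void set_scales(ParamsSketch &params, float softmax_scale, float keep_prob) {
        params.p_dropout = keep_prob;
        params.rp_dropout = 1.f / params.p_dropout;
        params.scale_bmm1f = softmax_scale;
        // One fused fp32 multiplier for the dQ/dK epilogues below.
        params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
    }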
@@ -13,13 +13,13 @@ namespace fmha {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
-inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
                                 Gmem_softmax_sum gmem_softmax_d, int tidx) {
     float sum[M];
     fmha::SumOp<float> sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
+        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale;
     }
     const int dp_sum_row = tidx / THREADS_PER_ROW;
     if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
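A scalar reference for what this tile code computes, as I read it: row i of gmem_softmax_d receives scale * sum_j dO[i][j] * O[i][j], with the reduction done in fp32 (hmulsum8 multiplies fp16 pairs and accumulates in fp32). The call sites in the hunks below pass params.p_dropout, the keep probability, as scale. A plain C++ sketch of mine:

    #include <vector>

    // Scalar model of dot_do_o (a sketch, not the kernel): per-row dot
    // product of dO and O, accumulated in fp32, then multiplied by `scale`.
    std::vector<float> dot_do_o_ref(const std::vector<std::vector<float>>& dO,
                                    const std::vector<std::vector<float>>& O,
                                    float scale) {
        std::vector<float> softmax_d(dO.size(), 0.f);
        for (size_t i = 0; i < dO.size(); ++i) {
            float sum = 0.f;
            for (size_t j = 0; j < dO[i].size(); ++j) {
                sum += dO[i][j] * O[i][j];
            }
            softmax_d[i] = sum * scale;
        }
        return softmax_d;
    }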
...
@@ -213,18 +213,18 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_do.commit(smem_do);
     if (Is_first) {
         dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
-            gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+            gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
         );
     }
-    // Instead of scaling dP by rp_dropout, we scale V instead
-    if (Is_dropout) {
-        const uint32_t scale_dropout = params.scale_dropout;
-        #pragma unroll
-        for(int it=0; it < Gmem_tile_v::LDGS; it++){
-            gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
-        }
-    }
+    // // Instead of scaling dP by rp_dropout, we scale V instead
+    // if (Is_dropout) {
+    //     const uint32_t scale_dropout = params.scale_dropout;
+    //     #pragma unroll
+    //     for(int it=0; it < Gmem_tile_v::LDGS; it++){
+    //         gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
+    //     }
+    // }
     gmem_v.commit(smem_v);
...
@@ -518,7 +518,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_do.commit(smem_do);
     if (Is_first) {
         dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
-            gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+            gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
         );
     }
     gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
...
@@ -569,7 +569,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
     // }
     for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
-        dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+        // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+        dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
     }
     // Output the values.
     gmem_dq.store(dq_out, 0);
...
@@ -613,7 +614,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
         for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
             // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
-            acc_dk[mi][ni].mul_(params.scale_bmm1f);
+            // acc_dk[mi][ni].mul_(params.scale_bmm1f);
+            acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout);
         }
     }
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
...
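The net effect of the last two hunks, in a toy device helper of mine (__floats2half2_rn is the standard CUDA intrinsic; the rest stands in for fmha::fmul4 plus the fp16 store): the dQ/dK accumulators are fp32, so the fused rescale is applied in fp32 and the values are rounded to fp16 only once, at the store. This is the "scaling in fp32" the commit message refers to.

    #include <cuda_fp16.h>

    // Toy epilogue (not the repo's code): multiply the fp32 accumulator by
    // the fused scale, then round to fp16 exactly once when storing.
    __device__ __half2 scale_and_store(float2 acc, float scale_bmm1_rp_dropout) {
        acc.x *= scale_bmm1_rp_dropout;  // fp32 multiply on the accumulator
        acc.y *= scale_bmm1_rp_dropout;
        return __floats2half2_rn(acc.x, acc.y);  // single rounding to fp16
    }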