"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "241757298f7e5abb3e199229fde2513228de82b9"
Commit 5b838a8b authored by Tri Dao

Apply dropout scaling to dQ and dK instead of to V (in bwd)

Theoretically this might have lower numerical error, since the scaling is done in
fp32 instead of fp16 (not sure, I haven't thought about it too carefully).
In practice, however, the numerical errors seem about the same.
parent a5559a0e
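The equivalence behind the change described above can be checked with a scalar toy model. The sketch below is my own illustration, not code from this repo: the variable names (keep_prob, softmax_scale, etc.) are made up, and a single attention entry stands in for the real tiles. It shows that dropping the fp16 pre-scaling of V, scaling the dO·O reduction by the keep probability (which this codebase appears to store in params.p_dropout, since rp_dropout = 1 / p_dropout is documented as 1 / (1 - dropout)), and folding 1 / keep into the final fp32 scale on dQ and dK yields the same gradients as the old scheme.

```cpp
// Hedged scalar sketch of the backward-pass dropout scaling, not the kernel code.
#include <cassert>
#include <cmath>

int main() {
    // Toy scalars standing in for one attention entry.
    float P  = 0.25f;              // softmax probability
    float V  = 0.5f, dO = 2.0f;    // value entry and output gradient
    float K  = 1.0f;               // key entry
    float softmax_scale = 0.125f;  // e.g. 1/sqrt(64), i.e. params.scale_bmm1f
    float keep_prob     = 0.9f;    // keep probability (what params.p_dropout appears to hold)
    float rp            = 1.0f / keep_prob;  // params.rp_dropout
    float kept          = 1.0f;   // dropout mask: 1 if kept, 0 if dropped

    // Forward with dropout: O = (P * mask / keep) * V.
    float O = P * kept * rp * V;
    // Softmax backward: dS = P * (dP - D), where D = dO * O (a row-sum in the real code).

    // Old scheme: V is pre-scaled by 1/keep in fp16, so dP already carries the 1/keep.
    float dP_old = dO * (V * rp) * kept;
    float D_old  = dO * O;                      // no extra scaling on the reduction
    float dS_old = P * (dP_old - D_old);
    float dQ_old = dS_old * K * softmax_scale;  // final scale: scale_bmm1f only

    // New scheme: V is untouched, so dP and D are both smaller by a factor of keep;
    // the missing 1/keep is applied once, in fp32, to dQ (and likewise to dK).
    float dP_new = dO * V * kept;
    float D_new  = (dO * O) * keep_prob;        // dot_do_o now multiplies by params.p_dropout
    float dS_new = P * (dP_new - D_new);
    float dQ_new = dS_new * K * softmax_scale * rp;  // final scale: scale_bmm1_rp_dropout

    // Both placements of the 1/keep factor give the same dQ (up to rounding).
    assert(std::fabs(dQ_old - dQ_new) < 1e-6f);
    return 0;
}
```

The only difference between the two schemes is where the 1/keep multiply happens: on the fp16 V tiles (old) or on the fp32 dQ/dK accumulators (new), which is the precision point the commit message raises.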
@@ -107,6 +107,7 @@ void set_params_fprop(FMHA_fprop_params &params,
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;
+ params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
TORCH_CHECK(p_dropout < 1.f);
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
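For concreteness, here is a small self-contained example of what the new combined factor works out to under assumed values (dropout probability 0.1, head dimension 64); none of these numbers come from the repo.

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Illustrative values only.
    float p_dropout   = 0.9f;                     // keep probability, as this codebase stores it
    float rp_dropout  = 1.0f / p_dropout;         // ~1.1111, i.e. 1 / (1 - dropout)
    float scale_bmm1f = 1.0f / std::sqrt(64.0f);  // softmax scale for head dimension 64
    float scale_bmm1_rp_dropout = rp_dropout * scale_bmm1f;
    std::printf("fp32 scale applied to dQ/dK: %f\n", scale_bmm1_rp_dropout);  // ~0.1389
    return 0;
}
```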
@@ -719,4 +720,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)");
m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
- }
+ }
\ No newline at end of file
@@ -115,6 +115,7 @@ struct FMHA_fprop_params : public Qkv_params {
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
+ float scale_bmm1_rp_dropout;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t scale_dropout;
@@ -13,13 +13,13 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
- inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+ inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
Gmem_softmax_sum gmem_softmax_d, int tidx) {
float sum[M];
fmha::SumOp<float> sum_op;
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
- sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
+ sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale;
}
const int dp_sum_row = tidx / THREADS_PER_ROW;
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
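The dot_do_o change above threads a scale argument into the dO·O row reduction; the call sites below pass params.p_dropout for it, which appears to be the keep probability (so the scale is 1 when dropout is off). Here is a rough host-side analogue, written as a plain C++ sketch rather than the packed-half, warp-level reduction the kernel actually performs.

```cpp
#include <cstddef>
#include <vector>

// Simplified per-row analogue of dot_do_o: D[i] = scale * sum_j dO[i][j] * O[i][j].
std::vector<float> dot_do_o_rowsum(const std::vector<std::vector<float>>& dO,
                                   const std::vector<std::vector<float>>& O,
                                   float scale) {
    std::vector<float> D(dO.size(), 0.0f);
    for (std::size_t i = 0; i < dO.size(); ++i) {
        float sum = 0.0f;
        for (std::size_t j = 0; j < dO[i].size(); ++j) {
            sum += dO[i][j] * O[i][j];
        }
        D[i] = sum * scale;  // scale = params.p_dropout at the call sites in this commit
    }
    return D;
}
```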
@@ -213,18 +213,18 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
- gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+ gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
- // Instead of scaling dP by rp_dropout, we scale V instead
- if (Is_dropout) {
- const uint32_t scale_dropout = params.scale_dropout;
- #pragma unroll
- for(int it=0; it < Gmem_tile_v::LDGS; it++){
- gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
- }
- }
+ // // Instead of scaling dP by rp_dropout, we scale V instead
+ // if (Is_dropout) {
+ // const uint32_t scale_dropout = params.scale_dropout;
+ // #pragma unroll
+ // for(int it=0; it < Gmem_tile_v::LDGS; it++){
+ // gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
+ // }
+ // }
gmem_v.commit(smem_v);
@@ -518,7 +518,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
- gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+ gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
@@ -569,7 +569,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
// }
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
- dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+ // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+ dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
}
// Output the values.
gmem_dq.store(dq_out, 0);
@@ -613,7 +614,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
- acc_dk[mi][ni].mul_(params.scale_bmm1f);
+ // acc_dk[mi][ni].mul_(params.scale_bmm1f);
+ acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -692,4 +694,4 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
- } // namespace fmha
+ } // namespace fmha
\ No newline at end of file