Commit c2d566ff authored by ltqin
Browse files

fix scale

parent bfe2c1dc
...@@ -285,7 +285,6 @@ int run(int argc, char* argv[]) ...@@ -285,7 +285,6 @@ int run(int argc, char* argv[])
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
float scale_rp_dropout = alpha * rp_dropout;
if(argc == 1) if(argc == 1)
{ {
...@@ -539,7 +538,7 @@ int run(int argc, char* argv[]) ...@@ -539,7 +538,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
......
...@@ -84,7 +84,6 @@ __global__ void ...@@ -84,7 +84,6 @@ __global__ void
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
const float p_dropout, const float p_dropout,
const float rp_dropout,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
...@@ -140,7 +139,6 @@ __global__ void ...@@ -140,7 +139,6 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout, p_dropout,
rp_dropout,
ph); ph);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -784,6 +782,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle ...@@ -784,6 +782,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
rp_dropout_ = 1.f / p_dropout_; rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -960,7 +960,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle ...@@ -960,7 +960,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_, arg.p_dropout_in_16bits_,
arg.p_dropout_, arg.p_dropout_,
arg.rp_dropout_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_);
}; };
......
...@@ -95,6 +95,8 @@ struct Scale ...@@ -95,6 +95,8 @@ struct Scale
y = scale_ * x; y = scale_ * x;
}; };
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
float scale_; float scale_;
}; };
......
...@@ -1171,9 +1171,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1171,9 +1171,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout, FloatGemmAcc p_dropout,
FloatGemmAcc rp_dropout,
ck::philox& ph) ck::philox& ph)
{ {
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize()); p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment