Commit c2d566ff authored by ltqin
Browse files

fix scale

parent bfe2c1dc
......@@ -285,7 +285,6 @@ int run(int argc, char* argv[])
const unsigned long long seed = 1;
const unsigned long long offset = 0;
float scale_rp_dropout = alpha * rp_dropout;
if(argc == 1)
{
......@@ -539,7 +538,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
......
......@@ -84,7 +84,6 @@ __global__ void
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const float p_dropout,
const float rp_dropout,
const unsigned long long seed,
const unsigned long long offset)
{
......@@ -140,7 +139,6 @@ __global__ void
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout,
rp_dropout,
ph);
#else
ignore = p_a_grid;
......@@ -784,6 +782,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -960,7 +960,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_,
arg.rp_dropout_,
arg.seed_,
arg.offset_);
};
......
......@@ -95,6 +95,8 @@ struct Scale
y = scale_ * x;
};
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
float scale_;
};
......
......@@ -1171,9 +1171,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout,
FloatGemmAcc rp_dropout,
ck::philox& ph)
{
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment