Commit 2e1f2c37 authored by ltqin
Browse files

fix datatype

parent 32b03f33
......@@ -1263,8 +1263,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
......
......@@ -908,7 +908,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(FloatGemmAcc);
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment