Unverified Commit 07bfa49a authored by guangzlu, committed by GitHub

update to mha develop (#922)



* uint8 dropout

* bias examples sync with uint8 dropout

* remove unused code

* disable kloop stuff

---------
Co-authored-by: danyao12 <danyao12@amd.com>
parent e114d48b
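
This commit replaces the 16-bit dropout mask (ushort values tested against floor(p_dropout * 65535)) with an 8-bit mask (uint8_t values tested against floor(p_dropout * 255)) throughout the MHA forward and backward kernels, halving the per-element mask footprint in VGPRs, LDS, and the Z output tensor. A minimal host-side sketch of the keep/drop rule the diff adopts; the helper names (make_threshold, apply_dropout) are illustrative, not kernel symbols:

// Host-side sketch of the uint8 keep/drop rule this commit adopts.
// make_threshold / apply_dropout are illustrative names, not kernel symbols.
#include <cmath>
#include <cstdint>

// Mirrors uint8_t(std::floor(p_dropout * 255.0)) in the kernels: quantize the
// keep-probability (p_dropout = 1 - p_drop) to an 8-bit threshold.
inline uint8_t make_threshold(float p_dropout)
{
    return static_cast<uint8_t>(std::floor(p_dropout * 255.0f));
}

// An element survives when its random byte is <= threshold, and is rescaled
// by rp_dropout = 1 / p_dropout so the expectation is unchanged.
inline float apply_dropout(float x, uint8_t random_byte, uint8_t threshold, float rp_dropout)
{
    return random_byte <= threshold ? x * rp_dropout : 0.0f;
}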
@@ -117,8 +117,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
     static constexpr auto V_N1 = NXdlPerWave;
     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1487,8 +1487,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1806,7 +1806,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                                          decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};
         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1856,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                        n2)); // NPerXdl
         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1866,7 +1866,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
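
The DropoutTile change above follows from the generator width: one philox draw now fills 16 uint8_t values (get_random_16x8) where it previously filled 8 ushort values (get_random_8x16), and DropoutNThread == mfma.num_input_blks == 2 threads cooperate per tile, per the inline comments. A compile-time sketch of that arithmetic, using only constants stated in the diff:

#include <cstddef>

constexpr std::size_t DropoutNThread       = 2;  // mfma.num_input_blks, per the comment
constexpr std::size_t randoms_per_call_old = 8;  // get_random_8x16: 8 x 16-bit values
constexpr std::size_t randoms_per_call_new = 16; // get_random_16x8: 16 x 8-bit values

static_assert(DropoutNThread * randoms_per_call_old == 16, "old DropoutTile");
static_assert(DropoutNThread * randoms_per_call_new == 32, "new DropoutTile");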
@@ -130,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
     static constexpr auto V_N1 = NXdlPerWave;
     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1553,8 +1553,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1901,7 +1901,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                          decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};
         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1951,7 +1951,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        n2)); // NPerXdl
         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1961,7 +1961,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
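
One side effect worth noting before the forward-kernel diff: every Z mask buffer (the VGPR StaticBuffer, the LDS block behind z_block_bytes_end, and the global Z tensor) now stores one byte per element instead of two. A back-of-envelope sketch with an assumed element count; ushort is taken as 16-bit:

#include <cstdint>
#include <cstdio>

int main()
{
    // Example element count; the real value is the descriptor's element space size.
    constexpr unsigned long long z_shuffle_block_space_size = 4096;
    std::printf("ushort : %llu bytes\n", z_shuffle_block_space_size * sizeof(std::uint16_t));
    std::printf("uint8_t: %llu bytes\n", z_shuffle_block_space_size * sizeof(std::uint8_t));
}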
@@ -113,8 +113,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};
-    static constexpr auto I8 = Number<8>{};
-    static constexpr auto I9 = Number<9>{};
     static constexpr auto WaveSize = 64;
@@ -134,17 +132,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     static constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
-    static constexpr auto DropoutMThread = DropoutTile;               // 16
-    static constexpr auto DropoutTilePerXdl = NPerXdl / DropoutTile;  // 2
-    static constexpr auto DropoutStep = Number<DropoutStepValue>{};   // 1 2 4
-    static constexpr auto DropoutNRepeat =
-        Number<math::integer_divide_ceil(DropoutStep, DropoutTilePerXdl)>{}; // 1 1 2
-    static constexpr auto DropoutGroupPerTile =
-        Number<mfma.num_groups_per_blk / DropoutTilePerXdl>{}; // 2
-    static constexpr auto DropoutStepPerXdl =
-        Number<math::min(DropoutStep, DropoutTilePerXdl)>{}; // 1 2 2
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
+    static constexpr auto DropoutStep = Number<DropoutStepValue>{};    // 1 2
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -152,51 +142,45 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
     // C desc for source in gridwise copy
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
+    __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
         const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
     {
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);
         const auto M0 = M / MPerBlock;
-        const auto N0 = N / (DropoutNRepeat * NPerXdl);
+        const auto N0 = N / (DropoutStep * NPerXdl);
         constexpr auto M1 = MXdlPerWave;
-        constexpr auto N1 = DropoutNRepeat;
+        constexpr auto N1 = DropoutStep;
         constexpr auto M2 = Gemm0MWaves;
         constexpr auto N2 = Gemm0NWaves;
-        constexpr auto M3 = DropoutTilePerXdl;
-        constexpr auto N3 = DropoutStepPerXdl;
-        constexpr auto M4 = DropoutTile;
-        constexpr auto N4 = DropoutGroupPerTile;
-        constexpr auto N5 = mfma.num_input_blks;
-        constexpr auto N6 = mfma.group_size;
+        constexpr auto M3 = DropoutTile;
+        constexpr auto N3 = mfma.num_groups_per_blk;
+        constexpr auto N4 = mfma.num_input_blks;
+        constexpr auto N5 = mfma.group_size;
         return transform_tensor_descriptor(
             z_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
-                       make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5, N6))),
+            make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3)),
+                       make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4, 6, 8>{}, Sequence<1, 3, 5, 7, 9, 10, 11>{}));
+            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }

-    __host__ __device__ static constexpr auto
-    GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5()
+    __host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
     {
         constexpr auto M0 = MXdlPerWave;
-        constexpr auto N0 = DropoutNRepeat;
+        constexpr auto N0 = DropoutStep;
         constexpr auto M1 = Gemm0MWaves;
         constexpr auto N1 = Gemm0NWaves;
-        constexpr auto M2 = DropoutTilePerXdl;
-        constexpr auto N2 = DropoutStepPerXdl;
-        constexpr auto M3 = DropoutTile;
-        constexpr auto N3 = DropoutGroupPerTile;
-        constexpr auto N4 = mfma.num_input_blks;
-        constexpr auto N5 = mfma.group_size;
+        constexpr auto M2 = DropoutTile;
+        constexpr auto N2 = mfma.num_groups_per_blk;
+        constexpr auto N3 = mfma.num_input_blks;
+        constexpr auto N4 = mfma.group_size;

-        constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-            make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, M3, N3, N4, N5));
+        constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, N3, N4));

-        return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+        return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
     }

     __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
@@ -317,7 +301,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
         const index_t z_block_bytes_end =
-            SharedMemTrait::z_shuffle_block_space_size * sizeof(ushort);
+            SharedMemTrait::z_shuffle_block_space_size * sizeof(uint8_t);
         return math::max(gemm0_bytes_end,
                          gemm1_bytes_end,
@@ -468,8 +452,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
-    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 = remove_cvref_t<decltype(
-        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(ZGridDesc_M_N{}))>;
+    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
     struct SharedMemTrait
     {
@@ -507,10 +491,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
         // LDS allocation for Z shuffle in LDS
-        static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-            GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
+        static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
         static constexpr auto z_shuffle_block_space_size =
-            z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize();
+            z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize();
     };

     template <bool HasMainKBlockLoop,
@@ -538,12 +522,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-              const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6&
-                  z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+              const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+                  z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
               const LSEGridDesc_M& lse_grid_desc_m,
               const Block2CTileMap& block_2_ctile_map,
               const C0MatrixMask& c0_matrix_mask,
-              const ushort p_dropout_in_16bits,
+              const uint8_t p_dropout_in_uint8_t,
               FloatGemmAcc p_dropout_rescale,
               ck::philox& ph,
               const index_t z_random_matrix_offset,
@@ -894,7 +878,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                          decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, p_dropout_rescale};
+            p_dropout_in_uint8_t, p_dropout_rescale};
         const index_t num_gemm1_k_block_outer_loop =
             b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
@@ -992,26 +976,22 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                  wave_m_n_id[I0], // NInputIndex
                                  0));             // register number
-            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
+            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
                 make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
-                                                               DropoutNRepeat, // NRepeat
+                                                               DropoutStep, // NRepeat
                                                                m1, // MWaveId
                                                                n1, // NWaveId
-                                                               I1,
-                                                               DropoutStepPerXdl,
                                                                m2,
-                                                               DropoutGroupPerTile,
+                                                               n2,
                                                                n3,
                                                                n4)); // RegisterNum
-            constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = // for blockwise copy
+            constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = // for blockwise copy
                 make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
-                                                               DropoutNRepeat, // NRepeat
+                                                               DropoutStep, // NRepeat
                                                                m1, // MWaveId
                                                                n1, // NWaveId
-                                                               I1,
-                                                               DropoutStepPerXdl,
-                                                               DropoutGroupPerTile,
+                                                               n2,
                                                                n3,
                                                                n4, // RegisterNum
                                                                m2));
@@ -1020,180 +1000,150 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             // z vgpr copy to global
             //
             // z matrix threadwise desc
-            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
+            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
                 make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                                I1, // NBlockId
                                                                m0, // MRepeat
-                                                               DropoutNRepeat, // NRepeat
+                                                               DropoutStep, // NRepeat
                                                                m1, // MWaveId
                                                                n1, // NWaveId
-                                                               I1,
-                                                               DropoutStepPerXdl,
                                                                m2,
-                                                               DropoutGroupPerTile,
+                                                               n2,
                                                                n3,
                                                                n4)); // RegisterNum

-            constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-                GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
+            constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+                GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

-            constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0); // 1
-            constexpr auto ZN0 =
-                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1); // 1 1 2
-            constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2); // 4
-            constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3); // 1
-            constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4); // 2
-            constexpr auto ZN2 =
-                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5); // 1 2 2
-            constexpr auto ZM3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6); // 16
-            constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7); // 2
-            constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8); // 2
-            constexpr auto ZN5 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9); // 4
-
-            constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
-                transform_tensor_descriptor(
-                    z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                    make_tuple(make_pass_through_transform(ZM0),
-                               make_pass_through_transform(ZN0),
-                               make_pass_through_transform(ZM1),
-                               make_pass_through_transform(ZN1),
-                               make_pass_through_transform(ZM2),
-                               make_pass_through_transform(ZN2),
-                               make_unmerge_transform(make_tuple(ZM3 / ZN4 / ZN5, ZN4, ZN5)),
-                               make_merge_transform_v3_division_mod(make_tuple(ZN3, ZN4, ZN5))),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<3>{},
-                               Sequence<4>{},
-                               Sequence<5>{},
-                               Sequence<6>{},
-                               Sequence<7, 8, 9>{}),
-                    make_tuple(Sequence<0>{},
-                               Sequence<1>{},
-                               Sequence<2>{},
-                               Sequence<3>{},
-                               Sequence<4>{},
-                               Sequence<5>{},
-                               Sequence<6, 7, 8>{},
-                               Sequence<9>{}));
+            constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // 1
+            constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); // 1 2
+            constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // 4
+            constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); // 1
+            constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); // 4
+            constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); // 2
+            constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); // 4
+
+            constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                make_tuple(make_pass_through_transform(ZM0),
+                           make_pass_through_transform(ZN0),
+                           make_pass_through_transform(ZM1),
+                           make_pass_through_transform(ZN1),
+                           make_unmerge_transform(make_tuple(ZN2, ZN3, ZN4)),
+                           make_merge_transform_v3_division_mod(make_tuple(ZN2, ZN3, ZN4))),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5, 6, 7>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4, 5, 6>{},
+                           Sequence<7>{}));

             StaticBuffer<AddressSpaceEnum::Vgpr,
-                         ushort,
-                         z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
+                         uint8_t,
+                         z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
                          true>
                 z_tensor_buffer;
             z_tensor_buffer.Clear();

             auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6.GetElementSpaceSize());
+                p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
             auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<ushort*>(p_shared),
-                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+                static_cast<uint8_t*>(p_shared),
+                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());

-            auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
-                ushort,
-                ushort,
-                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                tensor_operation::element_wise::PassThrough,
-                Sequence<m0,             // MRepeat
-                         DropoutNRepeat, // NRepeat
-                         m1,             // MWaveId
-                         n1,             // NWaveId
-                         I1,
-                         DropoutStepPerXdl,
-                         m2,
-                         DropoutGroupPerTile,
-                         n3,
-                         n4>, // RegisterNum
-                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                9, // DstVectorDim
-                1, // DstScalarPerVector
-                InMemoryDataOperationEnum::Set,
-                1, // DstScalarStrideInVector
-                true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                      make_multi_index(0,           // MRepeat
-                                       0,           // NRepeat
-                                       wave_id[I0], // MWaveId
-                                       wave_id[I1], // NWaveId
-                                       wave_m_n_id[I1] / DropoutMThread,
-                                       0,
-                                       wave_m_n_id[I1] % DropoutMThread,
-                                       0,
-                                       wave_m_n_id[I0],
-                                       0),
-                      tensor_operation::element_wise::PassThrough{}};
-
-            auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-                ushort,
-                ushort,
-                decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                Sequence<m0,
-                         DropoutNRepeat,
-                         m1,
-                         n1,
-                         I1,
-                         DropoutStepPerXdl,
-                         DropoutGroupPerTile,
-                         n3,
-                         n4,
-                         m2>,
-                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                9,
-                1,
-                1,
-                true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                      make_multi_index(0,           // MRepeat
-                                       0,           // NRepeat
-                                       wave_id[I0], // MWaveId
-                                       wave_id[I1], // NWaveId
-                                       wave_m_n_id[I1] / DropoutMThread,
-                                       0,
-                                       0,
-                                       wave_m_n_id[I0],
-                                       0,
-                                       wave_m_n_id[I1] % DropoutMThread)};
-
-            auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-                ushort,
-                ZDataType,
-                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
-                decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
-                tensor_operation::element_wise::PassThrough,
-                Sequence<I1,             // MBlockId
-                         I1,             // NBlockID
-                         m0,             // MRepeat
-                         DropoutNRepeat, // NRepeat
-                         m1,             // MWaveId
-                         n1,             // NWaveId
-                         I1,
-                         DropoutStepPerXdl,
-                         m2,
-                         DropoutGroupPerTile,
-                         n3,
-                         n4>,
-                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>,
-                11, // DstVectorDim
-                1,  // DstScalarPerVector
-                InMemoryDataOperationEnum::Set,
-                1, // DstScalarStrideInVector
-                true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                      make_multi_index(block_work_idx_m, // MBlockId
-                                       0,                // NBlockId
-                                       0,                // mrepeat
-                                       0,                // nrepeat
-                                       wave_id[I0],      // MWaveId
-                                       wave_id[I1],      // NWaveId
-                                       wave_m_n_id[I1] / DropoutMThread,
-                                       0,
-                                       wave_m_n_id[I1] % DropoutMThread,
-                                       0,
-                                       wave_m_n_id[I0],
-                                       0),
-                      tensor_operation::element_wise::PassThrough{}};
+            auto z_tmp_thread_copy_vgpr_to_lds =
+                ThreadwiseTensorSliceTransfer_v1r3<uint8_t,
+                                                   uint8_t,
+                                                   decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                                   decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                                   tensor_operation::element_wise::PassThrough,
+                                                   Sequence<m0,          // MRepeat
+                                                            DropoutStep, // NRepeat
+                                                            m1,          // MWaveId
+                                                            n1,          // NWaveId
+                                                            m2,
+                                                            n2,
+                                                            n3,
+                                                            n4>, // RegisterNum
+                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                                   7, // DstVectorDim
+                                                   1, // DstScalarPerVector
+                                                   InMemoryDataOperationEnum::Set,
+                                                   1, // DstScalarStrideInVector
+                                                   true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                                                         make_multi_index(0,           // MRepeat
+                                                                          0,           // NRepeat
+                                                                          wave_id[I0], // MWaveId
+                                                                          wave_id[I1], // NWaveId
+                                                                          wave_m_n_id[I1],
+                                                                          0,
+                                                                          wave_m_n_id[I0],
+                                                                          0),
+                                                         tensor_operation::element_wise::PassThrough{}};
+
+            auto z_shuffle_thread_copy_lds_to_vgpr =
+                ThreadwiseTensorSliceTransfer_v2<uint8_t,
+                                                 uint8_t,
+                                                 decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                                 decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                                 Sequence<m0, DropoutStep, m1, n1, n2, n3, n4, m2>,
+                                                 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                                 7,
+                                                 1,
+                                                 1,
+                                                 true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                       make_multi_index(0,           // MRepeat
+                                                                        0,           // NRepeat
+                                                                        wave_id[I0], // MWaveId
+                                                                        wave_id[I1], // NWaveId
+                                                                        0,
+                                                                        wave_m_n_id[I0],
+                                                                        0,
+                                                                        wave_m_n_id[I1])};
+
+            auto z_thread_copy_vgpr_to_global =
+                ThreadwiseTensorSliceTransfer_v1r3<uint8_t,
+                                                   ZDataType,
+                                                   decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+                                                   decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+                                                   tensor_operation::element_wise::PassThrough,
+                                                   Sequence<I1,          // MBlockId
+                                                            I1,          // NBlockID
+                                                            m0,          // MRepeat
+                                                            DropoutStep, // NRepeat
+                                                            m1,          // MWaveId
+                                                            n1,          // NWaveId
+                                                            m2,
+                                                            n2,
+                                                            n3,
+                                                            n4>,
+                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+                                                   9, // DstVectorDim
+                                                   1, // DstScalarPerVector
+                                                   InMemoryDataOperationEnum::Set,
+                                                   1, // DstScalarStrideInVector
+                                                   true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                         make_multi_index(block_work_idx_m, // MBlockId
+                                                                          0,                // NBlockId
+                                                                          0,                // mrepeat
+                                                                          0,                // nrepeat
+                                                                          wave_id[I0],      // MWaveId
+                                                                          wave_id[I1],      // NWaveId
+                                                                          wave_m_n_id[I1],
+                                                                          0,
+                                                                          wave_m_n_id[I0],
+                                                                          0),
+                                                         tensor_operation::element_wise::PassThrough{}};

             if constexpr(Deterministic)
             {
@@ -1321,8 +1271,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 blockwise_softmax.Run(acc_thread_buf, workspace_buf);

-                constexpr auto iterator_offset = Number<8 * DropoutStep>{};
-                constexpr auto iterator_step = Number<n0 * n1 * n2 * n3 * n4 / 8 / DropoutStep>{};
+                constexpr auto iterator_offset = Number<16 * DropoutStep>{};
+                constexpr auto iterator_step = Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};

                 if constexpr(IsDropout) // dropout
                 {
@@ -1343,18 +1293,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                                          decltype(DropoutTile)>(
                         ph, global_elem_id, z_tensor_buffer);

-                    z_tmp_thread_copy_vgpr_to_lds.Run(
-                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                        z_tensor_buffer,
-                        z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        z_block_buf);
+                    z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                                      z_tensor_buffer,
+                                                      z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                                                      z_block_buf);

                     z_shuffle_thread_copy_lds_to_vgpr.Run(
-                        z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                        z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                         z_block_buf,
-                        z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                        z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                         z_tensor_buffer);

                     blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
@@ -1367,14 +1316,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                     if(p_z_grid && (gemm1_n_block_data_idx_on_grid == 0))
                     {
                         z_thread_copy_vgpr_to_global.Run(
-                            z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                            make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                            z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                            make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                             z_tensor_buffer,
-                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                             z_grid_buf);
                         z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                            make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
+                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                            make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                     }
                 });
             }
......
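
The iterator changes near the end of the forward diff keep the same invariant in the new layout: iterator_offset covers 16 random bytes times DropoutStep per step, and iterator_step shrinks so the two together still span the thread's whole acc slice, whose size now includes the m0 factor. A compile-time sketch; the m0, n0, ... extents below are illustrative assumptions, not the kernel's values:

#include <cstddef>

constexpr std::size_t m0 = 2, n0 = 1, n1 = 1, n2 = 4, n3 = 2, n4 = 4; // assumed extents
constexpr std::size_t DropoutStep = 2;

constexpr std::size_t slice_size      = m0 * n0 * n1 * n2 * n3 * n4;
constexpr std::size_t iterator_offset = 16 * DropoutStep;             // as in the diff
constexpr std::size_t iterator_step   = slice_size / 16 / DropoutStep; // as in the diff

// The product of the two recovers the slice size, so the dropout loop
// consumes exactly one mask byte per acc element.
static_assert(iterator_offset * iterator_step == slice_size, "mask bytes cover the slice");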
@@ -84,6 +84,19 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }

+    __device__ void get_random_16x8(uint8_t* out, const unsigned long long subsequence)
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+
+        out_tmp[0] = tmp_ph.x;
+        out_tmp[1] = tmp_ph.y;
+        out_tmp[2] = tmp_ph.z;
+        out_tmp[3] = tmp_ph.w;
+    }
+
     __device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
     {
         uint4 tmp_ph;
......
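
The new get_random_16x8 consumes the same 128-bit philox draw as its siblings and simply views it as 16 bytes: four 32-bit words are written through a uint32_t pointer into the byte buffer. A standalone host sketch of that packing; the uint4_t stub and memcpy (a well-defined stand-in for the kernel's reinterpret_cast) are illustrative:

#include <cstdint>
#include <cstdio>
#include <cstring>

struct uint4_t { std::uint32_t x, y, z, w; }; // stand-in for HIP's uint4

// One 128-bit draw fills 16 dropout mask bytes, as in get_random_16x8.
void fill_16x8(std::uint8_t* out, const uint4_t& draw)
{
    std::memcpy(out, &draw, 16);
}

int main()
{
    uint4_t draw{0x03020100u, 0x07060504u, 0x0b0a0908u, 0x0f0e0d0cu};
    std::uint8_t bytes[16];
    fill_16x8(bytes, draw);
    for(int i = 0; i < 16; ++i)
        std::printf("%u ", static_cast<unsigned>(bytes[i])); // 0 1 2 ... 15 on a little-endian host
    std::printf("\n");
}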
@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
         Argument(const Tensor<RefDataType>& ref,
                  const Tensor<InDataType>& in,
                  Tensor<OutDataType>& out,
-                 RefDataType p_dropout_in_16bits,
+                 RefDataType p_dropout_in_uint8_t,
                  float rp_dropout)
             : ref_(ref),
               in_(in),
               out_(out),
-              p_dropout_in_16bits_(p_dropout_in_16bits),
+              p_dropout_in_uint8_t_(p_dropout_in_uint8_t),
               rp_dropout_(rp_dropout)
         {
         }

         const Tensor<RefDataType>& ref_;
         const Tensor<InDataType>& in_;
         Tensor<OutDataType>& out_;
-        RefDataType p_dropout_in_16bits_;
+        RefDataType p_dropout_in_uint8_t_;
         float rp_dropout_;
     };
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
         {
             arg.out_.ForEach([&](auto& self, auto idx) {
                 self(idx) =
-                    arg.ref_(idx) <= arg.p_dropout_in_16bits_
+                    arg.ref_(idx) <= arg.p_dropout_in_uint8_t_
                         ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
                                                         ck::type_convert<float>(arg.rp_dropout_))
                         : 0;
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
     static auto MakeArgument(const Tensor<RefDataType>& ref,
                              const Tensor<InDataType>& in,
                              Tensor<OutDataType>& out,
-                             RefDataType p_dropout_in_16bits,
+                             RefDataType p_dropout_in_uint8_t,
                              float rp_dropout)
     {
-        return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
+        return Argument{ref, in, out, p_dropout_in_uint8_t, rp_dropout};
     }

     static auto MakeInvoker() { return Invoker{}; }
......
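
To close the loop, the reference operator's element rule can be mirrored without the CK Tensor machinery when eyeballing the new threshold behaviour. A plain-C++ sketch with illustrative names, assuming the random mask values arrive as uint8_t:

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> reference_dropout(const std::vector<std::uint8_t>& z, // random mask bytes
                                     const std::vector<float>& in,       // pre-dropout values
                                     std::uint8_t p_dropout_in_uint8_t,  // keep threshold
                                     float rp_dropout)                   // 1 / keep-probability
{
    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i] = z[i] <= p_dropout_in_uint8_t ? in[i] * rp_dropout : 0.0f;
    return out;
}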