Commit 14822d71 authored by Jing Zhang

merge

parents 5b02dfaf 80560ef2
@@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 CGridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
@@ -634,10 +634,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
             FloatCShuffle,              // typename SrcData,
             FloatC,                     // typename DstData,
-            decltype(
-                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
             Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
             5,                                          // index_t VectorDim,
             CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 CGridDesc_M_N{}))>;

     using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 C0GridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
             FloatC, // typename Src0Data,
             FloatC, // typename Src1Data,
             FloatC, // typename DstData,
-            decltype(
-                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
             Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
             5,                                          // index_t VectorDim,
             CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
             c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 CGridDesc_M_N{}))>;

     using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 C0GridDesc_M_N{}))>;

     using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
-        remove_cvref_t<decltype(
-            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        remove_cvref_t<
+            decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
                 C1GridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
             FloatC, // typename Src1Data,
             FloatC, // typename Src2Data,
             FloatC, // typename DstData,
-            decltype(
-                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
-            decltype(
-                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
+            decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
             Sequence<0, 1, 2, 3, 4, 5>,                 // typename DimAccessOrder,
             5,                                          // index_t VectorDim,
             CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
     using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
         Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));

-    using ThreadwiseWolfordDescReduce = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
+    using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<DimSubBlocks * DimThreadSize>{})));

     using ThreadwiseWelford =
         ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
......
@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
     int left_kPerBlock  = math::integer_divide_ceil(k, kGridSize);
     int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
     int kPerThread      = kRightmostBlock < K_BlockTileSize
                               ? 0
                               : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
     int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;

     if(kPerBlockTail > 0)
     {
......
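Note on the split-K tail logic above: the arithmetic is easier to check with concrete numbers. A minimal host-side sketch (all values hypothetical, chosen only to exercise the tail path):

    #include <cstdio>

    int main()
    {
        // hypothetical configuration, mirroring the names in the hunk above
        const int kRaw               = 1000; // raw reduction length
        const int k                  = 1024; // padded length (assumption)
        const int kGridSize          = 4;    // k-blocks in the grid
        const int K_BlockTileSize    = 64;   // KThreadClusterSize * KThreadSliceSize
        const int KThreadClusterSize = 16;
        const int KThreadSliceSize   = 4;

        const int left_kPerBlock  = (k + kGridSize - 1) / kGridSize;         // 256
        const int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); // 232

        // whole K tiles in the rightmost block, expressed per thread
        const int kPerThread = kRightmostBlock < K_BlockTileSize
                                   ? 0
                                   : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); // 12
        // leftover elements that need the tail path
        const int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize; // 40

        std::printf("kPerThread=%d kPerBlockTail=%d\n", kPerThread, kPerBlockTail);
    }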
@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

+        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+                      "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
+
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};

         constexpr auto ordered_src_access_lengths =
......
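Note on the added static_assert: src_access_lengths is computed with exact integer division, so a slice length that is not a multiple of SrcScalarPerVector would silently drop the remainder rather than fail. A standalone illustration of the failure mode the assert now catches (numbers hypothetical):

    #include <cstdio>

    int main()
    {
        constexpr int slice_length          = 10; // SliceLengths::At(SrcVectorDim)
        constexpr int src_scalar_per_vector = 4;  // SrcScalarPerVector

        // mirrors SliceLengths{} / src_scalar_per_access above
        constexpr int access_length = slice_length / src_scalar_per_vector; // 2

        // 2 vector accesses * 4 scalars = 8 of 10 elements copied; the last 2
        // would simply be skipped, which the static_assert rejects at compile time
        std::printf("covered %d of %d elements\n",
                    access_length * src_scalar_per_vector, slice_length);
    }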
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                SrcData v;
+                DstData v;

                 // apply element-wise operation
                 element_op_(v, src_vector_container.template AsType<SrcData>()[i]);

                 // apply type convert
-                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
+                dst_vector_container.template AsType<DstData>()(i) = v;
             });

             const bool is_dst_valid =
......
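Note on the v6r1 change: the temporary v now has the destination type, so the element-wise operator performs the conversion as it writes and the trailing type_convert becomes redundant. For a conversion-style op the two orderings agree, which is what makes the simplification safe; a minimal sketch (the PassThrough-like op and the types are stand-ins, not CK's):

    #include <cstdio>

    using SrcData = double;
    using DstData = float;

    // conversion-on-write element op, in the style of CK's PassThrough
    struct PassThrough
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
    };

    int main()
    {
        const SrcData src = 0.1; // not exactly representable in float
        PassThrough op;

        SrcData v_old;                               // old: op in SrcData...
        op(v_old, src);
        DstData d_old = static_cast<DstData>(v_old); // ...then convert

        DstData v_new;                               // new: op writes DstData directly
        op(v_new, src);

        std::printf("old=%.9g new=%.9g\n", (double)d_old, (double)v_new); // identical
    }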
@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1
         const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
         const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

-        const index_t AK0 = K / AK1;
-
         // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
         const auto out_grid_desc =
             make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Filter1x1Stride1Pad0)
         {
+            const index_t AK0 = math::integer_divide_ceil(K, AK1);
+
             // A: output tensor
             const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
                 out_grid_desc,
@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1
             const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
             const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

+            const index_t AK0 =
+                math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
+
             if constexpr(NDimSpatial == 2)
             {
                 // A: output tensor
@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

-                const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
+                const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
                     transform_tensor_descriptor(
                         out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
                         make_tuple(make_pass_through_transform(N),
@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
                                    make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                                    make_slice_transform(XDot, I0, XDotSlice),
                                    make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
-                                   make_unmerge_transform(make_tuple(AK0, AK1))),
+                                   make_pass_through_transform(K)),
                         make_tuple(Sequence<0>{},
                                    Sequence<1>{},
                                    Sequence<2>{},
@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1
                                    Sequence<2>{},
                                    Sequence<3>{},
                                    Sequence<4>{},
-                                   Sequence<5, 6>{}));
+                                   Sequence<5>{}));

-                const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                    out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
-                    make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
-                               make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
-                               make_pass_through_transform(AK1)),
-                    make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                    out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
+                    make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                               make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
+                    make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));

-                const auto out_gemmak0_gemmm_gemmak1_grid_desc =
+                const auto out_gemmk_gemmm_padded_grid_desc =
                     ck::tensor_operation::device::PadTensorDescriptor(
-                        out_gemmak0_gemmmraw_gemmak1_grid_desc,
-                        make_tuple(AK0, GemmMPerBlock, AK1),
-                        Sequence<false, DoPadGemmM, false>{});
+                        out_gemmk_gemmmraw_grid_desc,
+                        make_tuple(AK1, GemmMPerBlock),
+                        Sequence<true, DoPadGemmM>{});
+
+                const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                    out_gemmk_gemmm_padded_grid_desc,
+                    make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                               make_pass_through_transform(
+                                   out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                 return out_gemmak0_gemmm_gemmak1_grid_desc;
             }
@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
                                    Sequence<7>{}));

                 const auto
-                    out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
+                    out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
                         transform_tensor_descriptor(
                             out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
                             make_tuple(make_pass_through_transform(N),
@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
                                        make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                                        make_slice_transform(XDot, I0, XDotSlice),
                                        make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
-                                       make_unmerge_transform(make_tuple(AK0, AK1))),
+                                       make_pass_through_transform(K)),
                             make_tuple(Sequence<0>{},
                                        Sequence<1>{},
                                        Sequence<2>{},
@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1
                                        Sequence<4>{},
                                        Sequence<5>{},
                                        Sequence<6>{},
-                                       Sequence<7, 8>{}));
+                                       Sequence<7>{}));

-                const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                    out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
-                    make_tuple(
-                        make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)),
-                        make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
-                        make_pass_through_transform(AK1)),
-                    make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                    out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
+                    make_tuple(
+                        make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                        make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
+                    make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));

-                const auto out_gemmak0_gemmm_gemmak1_grid_desc =
+                const auto out_gemmk_gemmm_padded_grid_desc =
                     ck::tensor_operation::device::PadTensorDescriptor(
-                        out_gemmak0_gemmmraw_gemmak1_grid_desc,
-                        make_tuple(AK0, GemmMPerBlock, AK1),
-                        Sequence<false, DoPadGemmM, false>{});
+                        out_gemmk_gemmmraw_grid_desc,
+                        make_tuple(AK1, GemmMPerBlock),
+                        Sequence<true, DoPadGemmM>{});
+
+                const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                    out_gemmk_gemmm_padded_grid_desc,
+                    make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                               make_pass_through_transform(
+                                   out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                 return out_gemmak0_gemmm_gemmak1_grid_desc;
             }
@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1
         const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
         const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

-        const index_t BK0 = K / BK1;
-
         // assume packed
         // k_y_x_c for 2d or k_z_y_x_c for 3d
         const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Filter1x1Stride1Pad0)
         {
+            const index_t BK0 = math::integer_divide_ceil(K, BK1);
+
             // B: weight tensor
             const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
                 transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1
             const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
             const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

+            const index_t BK0 =
+                math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
+
             // B weight tensor
             if constexpr(NDimSpatial == 2)
             {
@@ -566,43 +586,47 @@ struct TransformConvBwdDataToGemm_v1
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

-                const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
-                    transform_tensor_descriptor(
-                        wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
-                        make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                   make_slice_transform(YDot, I0, YDotSlice),
-                                   make_slice_transform(XDot, I0, XDotSlice),
-                                   make_freeze_transform(i_ytilde),
-                                   make_freeze_transform(i_xtilde),
-                                   make_pass_through_transform(C)),
-                        make_tuple(Sequence<0>{},
-                                   Sequence<1>{},
-                                   Sequence<3>{},
-                                   Sequence<2>{},
-                                   Sequence<4>{},
-                                   Sequence<5>{}),
-                        make_tuple(Sequence<0, 1>{},
-                                   Sequence<2>{},
-                                   Sequence<3>{},
-                                   Sequence<>{},
-                                   Sequence<>{},
-                                   Sequence<4>{}));
+                const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
+                    wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
+                    make_tuple(make_pass_through_transform(K),
+                               make_slice_transform(YDot, I0, YDotSlice),
+                               make_slice_transform(XDot, I0, XDotSlice),
+                               make_freeze_transform(i_ytilde),
+                               make_freeze_transform(i_xtilde),
+                               make_pass_through_transform(C)),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<3>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<>{},
+                               Sequence<>{},
+                               Sequence<3>{}));

-                const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
-                    wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
-                    make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
-                               make_pass_through_transform(C),
-                               make_pass_through_transform(BK1)),
-                    make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                    wei_k_ydotslice_xdotslice_c_grid_desc,
+                    make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                               make_pass_through_transform(C)),
+                    make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));

-                const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
+                const auto wei_gemmk_gemmn_padded_grid_desc =
                     ck::tensor_operation::device::PadTensorDescriptor(
-                        wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
-                        make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
-                                   GemmNPerBlock,
-                                   BK1),
-                        Sequence<false, DoPadGemmN, false>{});
+                        wei_gemmk_gemmnraw_grid_desc,
+                        make_tuple(BK1, GemmNPerBlock),
+                        Sequence<true, DoPadGemmN>{});
+
+                const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
+                    wei_gemmk_gemmn_padded_grid_desc,
+                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                               make_pass_through_transform(
+                                   wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                 return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
             }
@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1
                                Sequence<5, 6>{},
                                Sequence<7>{}));

-                const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc =
+                const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
                     transform_tensor_descriptor(
                         wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
-                        make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                        make_tuple(make_pass_through_transform(K),
                                    make_slice_transform(ZDot, I0, ZDotSlice),
                                    make_slice_transform(YDot, I0, YDotSlice),
                                    make_slice_transform(XDot, I0, XDotSlice),
@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1
                                    Sequence<4>{},
                                    Sequence<6>{},
                                    Sequence<7>{}),
-                        make_tuple(Sequence<0, 1>{},
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
                                    Sequence<2>{},
                                    Sequence<3>{},
-                                   Sequence<4>{},
                                    Sequence<>{},
                                    Sequence<>{},
                                    Sequence<>{},
-                                   Sequence<5>{}));
+                                   Sequence<4>{}));

-                const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
-                    wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc,
-                    make_tuple(
-                        make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)),
-                        make_pass_through_transform(C),
-                        make_pass_through_transform(BK1)),
-                    make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
-                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                    wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
+                    make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                               make_pass_through_transform(C)),
+                    make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));

-                const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
+                const auto wei_gemmk_gemm_padded_grid_desc =
                     ck::tensor_operation::device::PadTensorDescriptor(
-                        wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
-                        make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
-                                   GemmNPerBlock,
-                                   BK1),
-                        Sequence<false, DoPadGemmN, false>{});
+                        wei_gemmk_gemmnraw_grid_desc,
+                        make_tuple(BK1, GemmNPerBlock),
+                        Sequence<true, DoPadGemmN>{});

-                return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
+                const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
+                    wei_gemmk_gemm_padded_grid_desc,
+                    make_tuple(
+                        make_unmerge_transform(make_tuple(BK0, BK1)),
+                        make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+                return wei_gemmbk0_gemm_gemmbk1_grid_desc;
             }
             else
             {
......
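Note on the recurring pattern in this file: the old code computed AK0/BK0 up front as K / AK1, which silently assumes AK1 divides the GemmK extent. The rewrite merges the full GemmK first, pads it to a multiple of AK1 (the new Sequence<true, ...> pad flag), and only then unmerges into (AK0, AK1) with AK0 = integer_divide_ceil(..., AK1). A shape-level sketch of the bookkeeping with plain integers (sizes hypothetical, not CK descriptors):

    #include <cstdio>

    int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        // hypothetical slice sizes for one (ytilde, xtilde) group, 2D case
        const int YDotSlice = 3, XDotSlice = 3, K = 70;
        const int AK1 = 8;

        // old: AK0 = K / AK1 is only valid when AK1 divides K (fails here: 70 % 8 != 0)
        // new: merge all K-like dims, round up, pad, then split into (AK0, AK1)
        const int GemmKRaw = YDotSlice * XDotSlice * K;          // 630
        const int AK0      = integer_divide_ceil(GemmKRaw, AK1); // 79
        const int GemmKPad = AK0 * AK1;                          // 632 -> 2 padding elements

        std::printf("GemmKRaw=%d padded=%d AK0=%d AK1=%d\n", GemmKRaw, GemmKPad, AK0, AK1);
    }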
@@ -94,8 +94,8 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
     const vector_type<half_t, 2> b_vector{b};

     static_for<0, 2, 1>{}([&](auto i) {
-        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
-             type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
+        c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
+             type_convert<float>(b_vector.AsType<half_t>()[i]);
     });
 #endif
 }
......
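Note on the inner_product fix: converting a half value through type_convert<int32_t> truncates the fraction, so the fp16 dot product was effectively computed on integers. A reduced repro of the difference (plain float stands in for half_t):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const float a[2] = {0.5f, 1.5f}; // stand-ins for the two half_t lanes
        const float b[2] = {0.5f, 2.0f};

        float c_new = 0.f; // new path: convert to float, accumulate in float
        float c_old = 0.f; // old path: products of int32_t-truncated operands
        for(int i = 0; i < 2; ++i)
        {
            c_new += a[i] * b[i];
            c_old += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        }
        std::printf("new=%f old=%f\n", c_new, c_old); // 3.25 vs 2.0
    }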
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
/* the configured version and settings for MIOpen Composable Kernel */
#ifndef CK_VERSION_H_
#define CK_VERSION_H_
// clang-format off
#define CK_VERSION @CMAKE_PROJECT_VERSION@
#define CK_VERSION_MAJOR @CMAKE_PROJECT_VERSION_MAJOR@
#define CK_VERSION_MINOR @CMAKE_PROJECT_VERSION_MINOR@
#define CK_VERSION_PATCH @CMAKE_PROJECT_VERSION_PATCH@
#define CK_COMMIT_ID @COMMIT_ID@
// clang-format on
#endif
@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator
                     ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }

-            AccDataType v_c;
+            CDataType v_c;

             arg.c_element_op_(v_c, v_acc);

-            arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
+            arg.c_m_n_(m, n) = v_c;
         };

         make_ParallelTensorFunctor(
......
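Note on the ReferenceGemm change: the output element-wise op now receives a CDataType destination, so the conversion to the output type happens inside the op rather than in a trailing type_convert, matching the ThreadwiseTensorSliceTransfer_v6r1 change above. A compressed sketch of the resulting reference loop (scalar types and the PassThrough op are stand-ins, not CK's):

    #include <cstdio>
    #include <vector>

    using ADataType   = float;
    using BDataType   = float;
    using AccDataType = double;
    using CDataType   = float;

    struct PassThrough
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
    };

    int main()
    {
        const int M = 2, N = 2, K = 3;
        std::vector<ADataType> a(M * K, 1.f);
        std::vector<BDataType> b(K * N, 2.f);
        std::vector<CDataType> c(M * N);
        PassThrough c_element_op;

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                AccDataType v_acc = 0; // accumulation type is unchanged
                for(int k = 0; k < K; ++k)
                    v_acc += static_cast<AccDataType>(a[m * K + k]) *
                             static_cast<AccDataType>(b[k * N + n]);

                CDataType v_c; // was AccDataType before this commit
                c_element_op(v_c, v_acc);
                c[m * N + n] = v_c; // no trailing type_convert needed
            }

        std::printf("c[0][0]=%f\n", c[0]); // 6.0
    }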
@@ -32,6 +32,8 @@ using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
+using F32_F32_Tuple = ck::Tuple<F32, F32>;

// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -95,10 +97,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
+using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
-using AddBias = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;

template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
@@ -37,7 +37,7 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
        DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -58,7 +58,7 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
        DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
-#ifdef __fp32__
+#ifdef CK_ENABLE_FP32
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
@@ -79,7 +79,7 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
        DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
                                                  Row,
@@ -154,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp32__
+#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
                     is_same_v<CDataType, float>)
        {
@@ -180,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
            }
        }
#endif
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<CDataType, half_t>)
        {
@@ -206,7 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
            }
        }
#endif
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
            }
        }
#endif
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
        if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<CDataType, int8_t>)
        {
......
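Note on the macro rename running through these headers: names like __fp16__ look like compiler builtins but do not appear to be predefined by hipcc or clang, so every guarded declaration was silently compiled out. CK_ENABLE_FP16 and friends are instead defined explicitly by the build (e.g. passed as -DCK_ENABLE_FP16; the exact CMake wiring is not shown in this commit). A trivial demonstration of the difference:

    #include <cstdio>

    int main()
    {
    #ifdef __fp16__
        std::puts("old guard: fp16 block compiled in");
    #else
        std::puts("old guard: fp16 block silently dropped (macro never defined)");
    #endif

    #ifdef CK_ENABLE_FP16
        std::puts("new guard: fp16 block compiled in");
    #else
        std::puts("new guard: build with -DCK_ENABLE_FP16 to enable fp16 instances");
    #endif
    }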
@@ -14,7 +14,7 @@
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -59,7 +59,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
                                            MaskingSpecialization::MaskDisabled>>>&
        instances);
#endif
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
                     Acc0BiasDataType::Size() == 1 &&
@@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory<
            }
        }
#endif
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
                          Acc0BiasDataType::Size() == 1 &&
......
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
......
@@ -19,7 +19,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
                                                        Row,
@@ -124,7 +124,7 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
                                                        PassThrough,
                                                        PassThrough>>>& instances);
#endif
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
                                                        Row,
@@ -263,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
@@ -297,7 +297,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
            }
        }
#endif
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
        else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
                          is_same_v<EDataType, int8_t>)
        {
......
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -59,7 +59,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
                                            MaskingSpecialization::MaskDisabled>>>&
        instances);
#endif
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
            }
        }
#endif
-#ifdef __bf16__
+#ifdef CK_ENABLE_BF16
        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
        {
......