Commit 14822d71 authored by Jing Zhang

merge

parents 5b02dfaf 80560ef2
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -147,7 +147,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                          StreamKReductionStrategy::Atomic)
             {
-                hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
+                hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
+                                                 0,
+                                                 karg.M * karg.N * sizeof(CDataType),
+                                                 stream_config.stream_id_));
                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
                                                   grid_dims,
...
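The hunk above swaps a default-stream hipMemset for hipMemsetAsync on stream_config.stream_id_, so clearing the output buffer for the atomic Stream-K reduction is enqueued on the same stream as the kernel launch that follows. A minimal standalone sketch of that pattern (plain HIP, placeholder buffer and kernel rather than the CK gridwise GEMM; not part of the commit):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void accumulate(float* c, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        atomicAdd(&c[i], 1.0f); // stand-in for the Stream-K atomic accumulation
}

int main()
{
    constexpr int n = 1 << 20;
    float* c        = nullptr;
    hipStream_t stream;
    (void)hipStreamCreate(&stream);
    (void)hipMalloc(&c, n * sizeof(float));

    // Zero the output on the stream, then launch: both operations are ordered by the
    // stream, mirroring hipMemsetAsync(karg.p_c_grid, ...) before launch_and_time_kernel.
    (void)hipMemsetAsync(c, 0, n * sizeof(float), stream);
    hipLaunchKernelGGL(accumulate, dim3((n + 255) / 256), dim3(256), 0, stream, c, n);
    (void)hipStreamSynchronize(stream);

    float host0 = 0.f;
    (void)hipMemcpy(&host0, c, sizeof(float), hipMemcpyDeviceToHost);
    std::printf("c[0] = %f\n", host0);

    (void)hipFree(c);
    (void)hipStreamDestroy(stream);
    return 0;
}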
@@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;

     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...
@@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
     using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -400,14 +404,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
         LoopSched>;

     // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     struct GroupedContractionBlock2ETileMap
     {
...
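The two hunks above give GridwiseGemmMultipleD_xdl_cshuffle separate A, B, and compute data types instead of a single shared one, with ComputeDataType currently aliased to ADataType. A simplified, illustrative-only template (hypothetical names, plain C++, not the CK interface) of why the extra parameter helps: storage and compute precision can now diverge per instantiation.

#include <cstddef>

// Hypothetical stand-in for a gridwise GEMM that carries the inner product in
// ComputeDataType while A/B keep their own storage types and AccDataType accumulates.
template <typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType>
struct GridwiseGemmSketch
{
    static AccDataType dot(const ADataType* a, const BDataType* b, std::size_t k)
    {
        AccDataType acc = 0;
        for(std::size_t i = 0; i < k; ++i)
            acc += static_cast<AccDataType>(static_cast<ComputeDataType>(a[i]) *
                                            static_cast<ComputeDataType>(b[i]));
        return acc;
    }
};

// Mirrors `using ComputeDataType = ADataType;` in the hunk: compute precision follows A by
// default, but a device op could now pass a wider type here without changing A/B storage.
using DefaultPrecisionSketch = GridwiseGemmSketch<float, float, float, double>;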
@@ -355,6 +355,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+        ABDataType, // TODO: distinguish A/B datatype
+        ABDataType, // TODO: distinguish A/B datatype
         ABDataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
@@ -422,10 +424,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
     using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));

-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}));
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}));

     // block-to-e-tile map
     using Block2ETileMap =
...
@@ -381,8 +381,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
     }

     // desc for problem definition
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_BK0_N_BK1 =
         remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
...
@@ -320,8 +320,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
     }

     // desc for problem definition
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_BK0_N_BK1 =
         remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
     using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
...
@@ -446,8 +446,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
         return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
     }

-    using AGridDesc_M_K = remove_cvref_t<decltype(
-        MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
     using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
     using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
@@ -507,10 +507,12 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
         RThreadTransferDstScalarPerVector_MPerBlock,
         LoopSched>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;

     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...
@@ -245,8 +245,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
     }

     // desc for problem definition
-    using AGridDesc_M_K = remove_cvref_t<decltype(
-        MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
     using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
...
@@ -361,15 +361,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     }

     // desc for problem definition
-    using AGridDesc_M_K = remove_cvref_t<decltype(
-        MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
     using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -412,14 +416,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         LoopSched>;

     // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     // block-to-e-tile map
     using Block2ETileMap =
...
@@ -735,12 +735,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

         // Check vector load/store requirement
        const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
                                         ? device_arg.a_mz_kz_strides_[1]
                                         : device_arg.a_mz_kz_strides_[0];
        const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
                                         ? device_arg.b_nz_kz_strides_[1]
                                         : device_arg.b_nz_kz_strides_[0];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? device_arg.b1_nz_kz_strides_[1]
                                          : device_arg.b1_nz_kz_strides_[0];
...
@@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -272,14 +276,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     struct GroupedGemmBlock2ETileMap
     {
...
@@ -114,7 +114,8 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched = make_default_loop_scheduler(),
           // Current implementation does not support multiple D fusions.
           enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
                           is_same_v<DsDataType, ck::Tuple<>>,
@@ -142,7 +143,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         AccDataType,
         EDataType,
         ALayout,
@@ -182,7 +184,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         CDEBlockTransferScalarPerVector_NPerBlock,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         LoopSched,
-        PipelineVersion::v1>;
+        PipelineVer>;

     using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
     using Block2ETileMapKSplit =
@@ -421,8 +423,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 for(const auto& trans_arg : arg.gemm_kernel_args_)
                 {
                     const auto& karg = trans_arg.karg_;
-                    hip_check_error(
-                        hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
+                    hip_check_error(hipMemsetAsync(karg.p_c_grid,
+                                                   0,
+                                                   karg.M * karg.N * sizeof(EDataType),
+                                                   stream_config.stream_id_));
                 }
             }
...
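The split-K grouped GEMM hunks above add a defaulted PipelineVersion template parameter and forward it to GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 in place of a hard-coded PipelineVersion::v1, and they move the per-group output clear to hipMemsetAsync on the invoker's stream. A compact sketch of the defaulted-parameter part (hypothetical names, not the CK templates): existing instantiations keep their old behavior, while new ones can opt into another pipeline.

// Local sketch of a pipeline-version enum; the real one lives in ck.
enum class PipelineVersion { v1, v2 };

template <int BlockSize, PipelineVersion Ver>
struct GridwiseGemmSketch
{
    static constexpr PipelineVersion version = Ver;
};

template <int BlockSize,
          PipelineVersion PipelineVer = PipelineVersion::v1> // new, defaulted -> source compatible
struct DeviceGemmSketch
{
    using GridwiseGemm = GridwiseGemmSketch<BlockSize, PipelineVer>; // was hard-coded to v1
};

static_assert(DeviceGemmSketch<256>::GridwiseGemm::version == PipelineVersion::v1,
              "default instantiations are unchanged");
static_assert(DeviceGemmSketch<256, PipelineVersion::v2>::GridwiseGemm::version ==
                  PipelineVersion::v2,
              "new instantiations can opt in");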
@@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-    using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
-    using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
+    using AGridDesc_AKB_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(
+            AGridDesc_M_K{}, 1))>;
+    using BGridDesc_BKB_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(
+            BGridDesc_N_K{}, 1))>;

     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
@@ -886,11 +888,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                     typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
                     has_main_loop>;

-                hipGetErrorString(hipMemset(
+                hipGetErrorString(hipMemsetAsync(
                     arg.p_e_grid_,
                     0,
                     arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-                        sizeof(EDataType)));
+                        sizeof(EDataType),
+                    stream_config.stream_id_));

                 return launch_and_time_kernel(stream_config,
                                               kernel,
...
@@ -36,6 +36,13 @@ struct Add
         y = x0 + type_convert<half_t>(x1);
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
+    {
+        y = type_convert<half_t>(x0 + x1);
+    };
+
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
...
@@ -195,6 +195,51 @@ struct AddMultiply
     }
 };

+// C = A * B
+// E = C x D0 + D1
+struct MultiplyAdd
+{
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
+                                                                        const half_t& c,
+                                                                        const half_t& d0,
+                                                                        const half_t& d1) const
+    {
+        const half_t y = (c * d0) + d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
+                                                                       const float& c,
+                                                                       const half_t& d0,
+                                                                       const half_t& d1) const
+    {
+        const half_t y = type_convert<half_t>(c) * d0 + d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
+                                                                      const float& c,
+                                                                      const half_t& d0,
+                                                                      const half_t& d1) const
+    {
+        const float y = c * d0 + d1;
+        e             = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
+                                                                     const float& c,
+                                                                     const float& d0,
+                                                                     const float& d1) const
+    {
+        const float y = c * d0 + d1;
+        e             = y;
+    }
+};
+
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {
...
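The new MultiplyAdd functor above computes e = c * d0 + d1 per element, where c is the GEMM output and d0/d1 come from the two auxiliary D tensors; its specializations differ only in which operands are half_t versus float. A host-side, plain-float analogue (hypothetical reference harness, not CK code) showing the contract:

#include <cstdio>
#include <vector>

struct MultiplyAddRef
{
    void operator()(float& e, const float& c, const float& d0, const float& d1) const
    {
        e = c * d0 + d1; // same arithmetic as the half_t/float specializations in the diff
    }
};

int main()
{
    const std::vector<float> c{1.0f, 2.0f, 3.0f};  // pretend GEMM output C = A * B
    const std::vector<float> d0{0.5f, 0.5f, 0.5f}; // per-element scale
    const std::vector<float> d1{1.0f, 1.0f, 1.0f}; // per-element bias
    std::vector<float> e(c.size());

    MultiplyAddRef op;
    for(std::size_t i = 0; i < c.size(); ++i)
        op(e[i], c[i], d0[i], d1[i]);

    for(float v : e)
        std::printf("%.1f ", v); // prints: 1.5 2.0 2.5
    std::printf("\n");
    return 0;
}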
@@ -40,21 +40,21 @@ struct PassThrough
     }

     template <>
-    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
     {
-        y = x;
+        y = type_convert<half_t>(x);
     }

     template <>
-    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
     {
         y = x;
     }

     template <>
-    __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
+    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
     {
-        y = type_convert<half_t>(x);
+        y = x;
     }

     template <>
@@ -132,34 +132,6 @@ struct PassThrough
     }
 };

-struct AddBias
-{
-    template <typename E, typename C, typename D0>
-    __host__ __device__ void operator()(E& e, const C& c, const D0& d0) const;
-
-    template <>
-    __host__ __device__ void
-    operator()<ck::half_t, float, float>(ck::half_t& e, const float& c, const float& d0) const
-    {
-        e = c + d0;
-    }
-
-    template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t, float>(ck::half_t& e,
-                                                                       const ck::half_t& c,
-                                                                       const float& d0) const
-    {
-        e = c + d0;
-    }
-
-    template <>
-    __host__ __device__ void
-    operator()<float, float, float>(float& e, const float& c, const float& d0) const
-    {
-        e = c + d0;
-    }
-};
-
 struct UnaryConvert
 {
     template <typename Y, typename X>
...
@@ -136,8 +136,8 @@ struct GridwiseMultiblockBatchNormForward
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));

     using ThreadwiseWelford1 =
         ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
...
@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...
@@ -121,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...
@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
         make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...