Commit f537f83f authored by root's avatar root
Browse files

add multi_d

parent 3afb2f74
...@@ -2,6 +2,9 @@ add_custom_target(example_gemm_add_add_fastgelu_xdl) ...@@ -2,6 +2,9 @@ add_custom_target(example_gemm_add_add_fastgelu_xdl)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16_int8 gemm_add_add_fastgelu_xdl_bf16_int8.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16_int8)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
...@@ -24,4 +27,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -24,4 +27,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multi_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -22,6 +22,7 @@ namespace device { ...@@ -22,6 +22,7 @@ namespace device {
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DsLayout,
typename CLayout, typename CLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
...@@ -67,6 +68,7 @@ template <typename ALayout, ...@@ -67,6 +68,7 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
BLayout, BLayout,
DsLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
...@@ -87,7 +89,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -87,7 +89,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
BDataType, BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
Tuple<>, DsDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -425,6 +427,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -425,6 +427,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even // Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -449,6 +452,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -449,6 +452,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
} }
else else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -474,6 +478,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -474,6 +478,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
else else
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -498,6 +503,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -498,6 +503,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
} }
else else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -527,6 +533,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -527,6 +533,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
const auto kernel = const auto kernel =
...@@ -537,6 +544,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -537,6 +544,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
Run(kernel); Run(kernel);
} }
else else
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...@@ -589,10 +597,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -589,10 +597,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const void* p_a,
const BDataType* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
CDataType* p_c, void* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
...@@ -600,15 +608,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -600,15 +608,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{static_cast<const ADataType*>(p_a),
p_b, static_cast<const BDataType*>(p_b),
p_ds, p_ds,
p_c, static_cast<CDataType*>(p_c),
M, M,
N, N,
K, K,
...@@ -616,7 +623,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -616,7 +623,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
KBatch, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op};
...@@ -636,7 +643,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -636,7 +643,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs, std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
...@@ -652,7 +658,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -652,7 +658,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
KBatch, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
......
...@@ -25,7 +25,6 @@ template <typename ALayout, ...@@ -25,7 +25,6 @@ template <typename ALayout,
typename CLayout, typename CLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename CDataType, typename CDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
...@@ -70,14 +69,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -70,14 +69,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout, ALayout,
...@@ -589,16 +585,16 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -589,16 +585,16 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
std::array<const void*, NumDTensor> p_ds,
CDataType* p_c, CDataType* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -607,14 +603,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -607,14 +603,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_ds, std::array<const void*, 0>{},
p_c, p_c,
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideDs, std::array<index_t, 0>{},
StrideC, StrideC,
KBatch, KBatch,
a_element_op, a_element_op,
...@@ -627,14 +623,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -627,14 +623,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -643,14 +637,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -643,14 +637,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds, std::array<const void*, 0>{},
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideDs, std::array<index_t, 0>{},
StrideC, StrideC,
KBatch, KBatch,
a_element_op, a_element_op,
......
...@@ -48,7 +48,10 @@ __global__ void ...@@ -48,7 +48,10 @@ __global__ void
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid, karg.p_c_grid,
p_shared, p_shared,
karg); karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -82,7 +85,10 @@ __global__ void ...@@ -82,7 +85,10 @@ __global__ void
karg.p_c_grid, karg.p_c_grid,
p_shared_0, p_shared_0,
p_shared_1, p_shared_1,
karg); karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -134,7 +140,9 @@ template <typename ALayout, ...@@ -134,7 +140,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
struct GridwiseGemm_xdl_cshuffle_v3 struct GridwiseGemm_xdl_cshuffle_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -167,9 +175,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -167,9 +175,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = static constexpr index_t KPack = math::max(
math::max(math::lcm(AK1Number, BK1Number), math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...@@ -662,9 +670,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -662,9 +670,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
// in some cases. // in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
? 1 ? 1
: 32 * 4 / KPerBlock / sizeof(ADataType); : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple( make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number), AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
...@@ -710,20 +718,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -710,20 +718,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadRead = 64 / MPerXdl; constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1 ? 1
: 128 / (AK1Number * M0 * sizeof(ADataType)); : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm = constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead; : KThreadRead;
// 1<=mpair<=n0 // 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
? 1 ? 1
: ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
? M0 ? M0
: 128 / (AK1Number * MPerXdl * sizeof(ADataType))); : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{}, make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
...@@ -798,9 +806,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -798,9 +806,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
// NLdsLayer * K0 as logical Bank // NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
? 1 ? 1
: 32 * 4 / KPerBlock / sizeof(BDataType); : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
; ;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple( make_tuple(
...@@ -844,20 +852,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -844,20 +852,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadRead = 64 / NPerXdl; constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1 ? 1
: 128 / (BK1Number * N0 * sizeof(BDataType)); : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm = constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead; : KThreadRead;
// 1<=npair<=n0 // 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
? 1 ? 1
: ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
? N0 ? N0
: 128 / (BK1Number * NPerXdl * sizeof(BDataType))); : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{}, make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
...@@ -940,8 +948,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -940,8 +948,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
BlkGemmPipelineVer, BlkGemmPipelineVer,
BlkGemmPipeSched, BlkGemmPipeSched,
BlockSize, BlockSize,
ADataType, LDSTypeA,
BDataType, LDSTypeB,
ComputeTypeA, ComputeTypeA,
AccDataType, AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
...@@ -983,8 +991,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -983,8 +991,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ADataType) + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
b_block_space_size_aligned * sizeof(BDataType)), b_block_space_size_aligned * sizeof(LDSTypeB)),
c_block_size * sizeof(CShuffleDataType)); c_block_size * sizeof(CShuffleDataType));
} }
...@@ -1203,7 +1211,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1203,7 +1211,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
DsGridPointer& p_ds_grid, DsGridPointer& p_ds_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
void* p_shared, void* p_shared,
const Problem& problem) const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{ {
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
...@@ -1223,10 +1234,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1223,10 +1234,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
...@@ -1270,7 +1277,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1270,7 +1277,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ADataType, ADataType,
ADataType, LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -1301,7 +1308,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1301,7 +1308,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BDataType, BDataType,
BDataType, LDSTypeB,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -1328,11 +1335,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1328,11 +1335,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
// Cast after lds // Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
...@@ -1721,7 +1728,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1721,7 +1728,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
CDataType* p_c_grid, CDataType* p_c_grid,
void* p_shared_0, void* p_shared_0,
void* p_shared_1, void* p_shared_1,
const Problem& problem) const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{ {
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
...@@ -1741,10 +1751,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1741,10 +1751,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
...@@ -1788,7 +1794,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1788,7 +1794,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ADataType, ADataType,
ADataType, LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -1819,7 +1825,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1819,7 +1825,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BDataType, BDataType,
BDataType, LDSTypeB,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -1845,19 +1851,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1845,19 +1851,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared_0) + static_cast<LDSTypeB*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared_1) + static_cast<LDSTypeB*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
......
...@@ -116,7 +116,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) ...@@ -116,7 +116,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment