Commit 97b371bd authored by Jing Zhang's avatar Jing Zhang
Browse files

pipeline_v2

parent 49180fd6
......@@ -46,7 +46,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
......@@ -54,9 +54,46 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    // Build a 16-group grouped-GEMM problem; every group uses the same
    // M/N/K shape (157 x 768 x 4608).
    problem_size.group_count = 16;

    // problem_size.Ms = {
    //     167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};

    for(int i = 0; i < problem_size.group_count; i++)
    {
        problem_size.Ms.push_back(157);
        problem_size.Ns.push_back(768);
        problem_size.Ks.push_back(4608);

        // Leading dimensions: A and B stride by K, C strides by N.
        // NOTE(review): assumes the row-major-A / column-major-B-style layouts
        // configured by the ALayout/BLayout aliases above — confirm there.
        problem_size.stride_As.push_back(problem_size.Ks[i]);
        problem_size.stride_Bs.push_back(problem_size.Ks[i]);
        problem_size.stride_Cs.push_back(problem_size.Ns[i]);
    }

    if(argc == 4)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
    }
    else
    {
        // Print usage and exit when the argument count is wrong.
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n"); // fixed typo: was "0=n0"
        exit(0);
    }

    // run_grouped_gemm returns true on success; invert so the process
    // exit code is 0 on success.
    return !run_grouped_gemm(problem_size, config);
}
......@@ -54,7 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
......
......@@ -17,6 +17,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
namespace tensor_operation {
......@@ -239,7 +240,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumPrefetch, // NumGemmKPrefetchStage
NumPrefetch,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -270,7 +271,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
LoopSched,
PipelineVersion::v1>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
......
......@@ -182,7 +182,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
CDEBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVersion::v2>;
PipelineVersion::v1>;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
......
......@@ -67,7 +67,7 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v2>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
......@@ -34,7 +34,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
else if constexpr(PipelineVer == PipelineVersion::v2)
{
return GridwiseGemmPipeline_v2{};
return GridwiseGemmPipeline_v2<NumPrefetch>{};
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
struct GridwiseGemmPipeline_v2
// N-stage prefetch
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v2;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
return num_loop >= 2;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return (num_loop / 2) > 1;
return num_loop > 2;
}
template <bool HasMainLoop,
......@@ -79,10 +82,6 @@ struct GridwiseGemmPipeline_v2
do
{
#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
__builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
#endif
block_sync_lds();
// GEMM i
......@@ -129,4 +128,512 @@ struct GridwiseGemmPipeline_v2
}
};
// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<2>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // The pipeline issues 3 global reads before the first GEMM (2 in the
    // prologue plus 1 more after the first LDS write), so the K loop must
    // have at least 3 iterations.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop >= 3;
    }

    // The main body consumes 2 iterations per trip and the tail handles the
    // final 3 or 4 iterations, so a main loop exists only for num_loop >= 5.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop >= 5;
    }

    // Run the whole K loop of a blockwise GEMM with 2-deep global prefetch.
    //
    // a_blockwise_copy / b_blockwise_copy stage each tile in two steps:
    //   RunRead(grid_desc, grid_buf, <slot>)    : global memory -> staging slot
    //   RunWrite(block_desc, block_buf, <slot>) : staging slot  -> LDS
    // NOTE(review): the Number<> argument selects a double-buffer slot inside
    // the transfer engine — confirm against the blockwise-copy implementation.
    // blockwise_gemm.Run accumulates one A/B tile product into c_thread_buf.
    // num_loop is the total number of K iterations (tiles).
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // Prologue: global reads for iterations 0 and 1 into slots 0 and 1.
        static_for<0, 2, 1>{}([&](auto i_pre) {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // advance the global read window to iteration i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        // LDS write for iteration 0 frees slot 0, which is immediately reused
        // for the global read of iteration 2.
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);

        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        index_t i = 0;

        // Main body: two fully unrolled iterations per trip, ping-ponging
        // between slots 0 and 1. On entry to sub-iteration i_main, the LDS
        // holds the tile for iteration i + i_main.
        if constexpr(HasMainLoop)
        {
            do
            {
                static_for<0, 2, 1>{}([&](auto i_main) {
                    block_sync_lds();

                    // GEMM for iteration i + i_main
                    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                    block_sync_lds();

                    // advance the global read window (next read targets
                    // iteration i + i_main + 3)
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    // LDS write for iteration i + i_main + 1; the drained slot
                    // is refilled with the global read of iteration i + i_main + 3
                    a_blockwise_copy.RunWrite(
                        a_block_desc, a_block_buf, Number<(i_main + 1) % 2>{});
                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_main + 1) % 2>{});

                    b_blockwise_copy.RunWrite(
                        b_block_desc, b_block_buf, Number<(i_main + 1) % 2>{});
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_main + 1) % 2>{});
                });

                i += 2;
            } while(i < (num_loop - 4));
        }

        // Tail, 3 iterations left (num_loop-3 .. num_loop-1): all remaining
        // data is already staged, only LDS writes and GEMMs remain.
        if(i == num_loop - 3)
        {
            block_sync_lds();

            // GEMM iteration num_loop - 3
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            // LDS write for iteration num_loop - 2 (slot 1)
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);

            block_sync_lds();

            // GEMM iteration num_loop - 2
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            // LDS write for iteration num_loop - 1 (slot 0)
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

            block_sync_lds();

            // GEMM iteration num_loop - 1 (last)
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
        // Tail, 4 iterations left: only the first sub-iteration still moves
        // the window and issues one more global read (into slot 1); later
        // sub-iterations just drain the staged data.
        else if(i == num_loop - 4)
        {
            static_for<0, 4, 1>{}([&](auto i_res) {
                block_sync_lds();

                // GEMM iteration num_loop - 4 + i_res
                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                if constexpr(i_res < 3)
                {
                    block_sync_lds();

                    if constexpr(i_res < 1)
                    {
                        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    }

                    // LDS write for iteration num_loop - 3 + i_res
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 2>{});
                    if constexpr(i_res < 1)
                        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);

                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 2>{});
                    if constexpr(i_res < 1)
                        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
                }
            });
        }
    }
};
// 3-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<3>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    // The pipeline issues 4 global reads before the first GEMM (3 in the
    // prologue plus 1 more after the first LDS write), so the K loop must
    // have at least 4 iterations.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop >= 4;
    }

    // The main body consumes 3 iterations per trip and the tail handles the
    // final 4, 5 or 6 iterations, so a main loop exists only for num_loop >= 7.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop >= 7;
    }

    // Run the whole K loop of a blockwise GEMM with 3-deep global prefetch,
    // rotating through 3 buffer slots.
    //
    // a_blockwise_copy / b_blockwise_copy stage each tile in two steps:
    //   RunRead(grid_desc, grid_buf, <slot>)    : global memory -> staging slot
    //   RunWrite(block_desc, block_buf, <slot>) : staging slot  -> LDS
    // NOTE(review): the Number<> argument selects a buffer slot inside the
    // transfer engine — confirm against the blockwise-copy implementation.
    // blockwise_gemm.Run accumulates one A/B tile product into c_thread_buf.
    // num_loop is the total number of K iterations (tiles).
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // Prologue: global reads for iterations 0..2 into slots 0..2.
        static_for<0, 3, 1>{}([&](auto i_pre) {
            // global read i_pre
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // advance the global read window to iteration i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        // LDS write for iteration 0 frees slot 0, which is immediately reused
        // for the global read of iteration 3.
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);

        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        index_t i = 0;

        // Main body: three fully unrolled iterations per trip, rotating
        // slots 0 -> 1 -> 2 -> 0. On entry to sub-iteration i_main, the LDS
        // holds the tile for iteration i + i_main.
        if constexpr(HasMainLoop)
        {
            do
            {
                static_for<0, 3, 1>{}([&](auto i_main) {
                    block_sync_lds();

                    // GEMM for iteration i + i_main
                    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                    block_sync_lds();

                    // advance the global read window (next read targets
                    // iteration i + i_main + 4)
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    // LDS write for iteration i + i_main + 1; the drained slot
                    // is refilled with the next global read
                    a_blockwise_copy.RunWrite(
                        a_block_desc, a_block_buf, Number<(i_main + 1) % 3>{});
                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_main + 1) % 3>{});

                    b_blockwise_copy.RunWrite(
                        b_block_desc, b_block_buf, Number<(i_main + 1) % 3>{});
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_main + 1) % 3>{});
                });

                i += 3;
            } while(i < (num_loop - 6));
        }

        // Tail, 6 iterations left: the first two sub-iterations still move
        // the window and issue one more global read each; the rest only
        // drain the staged data (writes stop after i_res == 4).
        if(i == num_loop - 6)
        {
            static_for<0, 6, 1>{}([&](auto i_res) {
                block_sync_lds();

                // GEMM iteration num_loop - 6 + i_res
                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                if constexpr(i_res < 5)
                {
                    block_sync_lds();

                    if constexpr(i_res < 2)
                    {
                        // advance the global read window
                        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    }

                    // LDS write for iteration num_loop - 5 + i_res
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
                    if constexpr(i_res < 2)
                        a_blockwise_copy.RunRead(
                            a_grid_desc, a_grid_buf, Number<(i_res + 1) % 3>{});

                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
                    if constexpr(i_res < 2)
                        b_blockwise_copy.RunRead(
                            b_grid_desc, b_grid_buf, Number<(i_res + 1) % 3>{});
                }
            });
        }
        // Tail, 5 iterations left: only the first sub-iteration still moves
        // the window and issues one more global read.
        else if(i == num_loop - 5)
        {
            static_for<0, 5, 1>{}([&](auto i_res) {
                block_sync_lds();

                // GEMM iteration num_loop - 5 + i_res
                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                if constexpr(i_res < 4)
                {
                    block_sync_lds();

                    if constexpr(i_res < 1)
                    {
                        // advance the global read window
                        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    }

                    // LDS write for iteration num_loop - 4 + i_res
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
                    if constexpr(i_res < 1)
                        a_blockwise_copy.RunRead(
                            a_grid_desc, a_grid_buf, Number<(i_res + 1) % 3>{});

                    // LDS write for iteration num_loop - 4 + i_res
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
                    if constexpr(i_res < 1)
                        b_blockwise_copy.RunRead(
                            b_grid_desc, b_grid_buf, Number<(i_res + 1) % 3>{});
                }
            });
        }
        // Tail, 4 iterations left: everything is already staged; only LDS
        // writes and GEMMs remain.
        else if(i == num_loop - 4)
        {
            static_for<0, 4, 1>{}([&](auto i_res) {
                block_sync_lds();

                // GEMM iteration num_loop - 4 + i_res
                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                if constexpr(i_res < 3)
                {
                    block_sync_lds();

                    // LDS write for iteration num_loop - 3 + i_res
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
                }
            });
        }
    }
};
// 4-stage prefetch
template <>
struct GridwiseGemmPipeline_v2<4>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    // This specialization only supports K loops whose iteration count is a
    // multiple of 4: the tail unconditionally drains exactly 4 iterations.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop % 4 == 0;
    }

    // The main body consumes 4 iterations per trip and the tail consumes the
    // last 4, so a main loop exists only when there are more than 4 total.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop / 4 > 1;
    }

    // Run the whole K loop of a blockwise GEMM with 4-deep global prefetch.
    //
    // a_blockwise_copy / b_blockwise_copy stage each tile in two steps:
    //   RunRead(grid_desc, grid_buf, <slot>)    : global memory -> staging slot
    //   RunWrite(block_desc, block_buf, <slot>) : staging slot  -> LDS
    // NOTE(review): the Number<> argument selects a buffer slot inside the
    // transfer engine — confirm against the blockwise-copy implementation.
    // blockwise_gemm.Run accumulates one A/B tile product into c_thread_buf.
    // num_loop is the total number of K iterations (tiles).
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // Prologue: global reads for iterations 0..3 into slots 0..3.
        static_for<0, 4, 1>{}([&](auto i_pre) {
            // global read i_pre
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // advance the global read window to iteration i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        index_t i = 0;

        // Main body: four fully unrolled iterations per trip. Each
        // sub-iteration drains slot i_main into LDS, refills it with the
        // next global read, then runs the GEMM on the freshly written tile.
        if constexpr(HasMainLoop)
        {
            do
            {
                static_for<0, 4, 1>{}([&](auto i_main) {
                    // LDS write for iteration i + i_main
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_main>{});
                    // global read for iteration i + i_main + 4
                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_main>{});

                    // LDS write for iteration i + i_main
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_main>{});
                    // global read for iteration i + i_main + 4
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_main>{});

                    // advance the global read window
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    block_sync_lds();
                    s_nop();

                    // GEMM for iteration i + i_main
                    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                    block_sync_lds();
                    s_nop();
                });

                i += 4;
            } while(i < (num_loop - 4));
        }

        // Tail: drain the 4 prefetched slots (iterations num_loop-4 ..
        // num_loop-1). Fixed: was `static_for<0, I4, 1>` — the bound is an
        // index_t non-type template parameter, so pass the integer literal
        // like every other static_for in this file.
        static_for<0, 4, 1>{}([&](auto i_res) {
            // LDS write for iteration num_loop - 4 + i_res
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});

            block_sync_lds();
            s_nop();

            // GEMM iteration num_loop - 4 + i_res
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();
            s_nop();
        });
    }
};
} // namespace ck
......@@ -85,7 +85,7 @@ template <index_t BlockSize,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v2>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
......@@ -711,7 +711,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumGemmKPrefetchStage>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -741,7 +742,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumGemmKPrefetchStage>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
......
# ckProfiler
set(PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_bilinear.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_add_fastgelu.cpp
profile_gemm_add_multiply.cpp
profile_gemm_add_fastgelu.cpp
profile_gemm_add_relu_add_layernorm.cpp
profile_gemm_fastgelu.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_gemm.cpp
profile_batched_gemm_add_relu_gemm_add.cpp
profile_batched_gemm_reduce.cpp
# profile_gemm.cpp
# profile_gemm_splitk.cpp
# profile_gemm_bilinear.cpp
# profile_gemm_bias_add_reduce.cpp
# profile_gemm_add_add_fastgelu.cpp
# profile_gemm_add_multiply.cpp
# profile_gemm_add_fastgelu.cpp
# profile_gemm_add_relu_add_layernorm.cpp
# profile_gemm_fastgelu.cpp
# profile_gemm_reduce.cpp
# profile_batched_gemm.cpp
# profile_batched_gemm_gemm.cpp
# profile_batched_gemm_add_relu_gemm_add.cpp
# profile_batched_gemm_reduce.cpp
profile_grouped_gemm.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm.cpp
profile_layernorm.cpp
profile_avg_pool2d_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_gemm_fastgelu.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_batched_gemm_multi_d.cpp
profile_grouped_conv_bwd_data.cpp
#profile_conv_fwd.cpp
#profile_conv_fwd_bias_relu.cpp
#profile_conv_fwd_bias_relu_add.cpp
#profile_conv_bwd_data.cpp
#profile_grouped_conv_fwd.cpp
#profile_grouped_conv_bwd_weight.cpp
#profile_reduce.cpp
#profile_groupnorm.cpp
#profile_layernorm.cpp
#profile_avg_pool2d_fwd.cpp
#profile_max_pool3d_fwd.cpp
#profile_softmax.cpp
#profile_batchnorm_fwd.cpp
#profile_batchnorm_bwd.cpp
#profile_batchnorm_infer.cpp
#profile_grouped_gemm_fastgelu.cpp
#profile_contraction_bilinear.cpp
#profile_contraction_scale.cpp
#profile_batched_gemm_multi_d.cpp
#profile_grouped_conv_bwd_data.cpp
)
set(PROFILER_EXECUTABLE ckProfiler)
......@@ -44,42 +44,42 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
......@@ -12,7 +12,7 @@ cmake
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment