"script/compile-rocm3.1.sh" did not exist on "f58bf38445c4b28c88a4bbe5bcb3f12e474650c8"
Commit e00a943e authored by myamlak's avatar myamlak
Browse files

Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

parents ffe12e2e 9f71ff48
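This merge changes every device operation's Run() interface from an int nrepeat timing parameter to a const StreamConfig& argument, moving stream selection and the decision to time a kernel out of the operators and into the caller. Judging from how the new argument is used in the hunks below (stream_config.stream_id_, stream_config.time_kernel_), a minimal sketch of the struct declared in stream_config.hpp would be the following; the field names are taken from the diff, the defaults are assumptions:

// Sketch of stream_config.hpp, inferred from its uses in this diff.
struct StreamConfig
{
    hipStream_t stream_id_ = nullptr; // stream to launch on (null = default stream)
    bool time_kernel_      = false;   // when true, launches are timed and averaged
};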
......@@ -385,8 +385,11 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
......@@ -408,35 +411,15 @@ struct DeviceGemmXdlSplitK
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ShowInfo(arg);
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
if(kbatch > 1 || nrepeat <= 0)
{
// FIXME: this should be moved outside of DeviceOp
hipGetErrorString(
hipMemset(arg.p_c_grid_,
0,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -450,8 +433,8 @@ struct DeviceGemmXdlSplitK
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
......@@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -391,8 +391,11 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
......@@ -414,36 +417,14 @@ struct DeviceGemmXdlSplitKCShuffle
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ShowInfo(arg);
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_kernel(kernel,
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -457,8 +438,8 @@ struct DeviceGemmXdlSplitKCShuffle
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
......@@ -542,9 +523,10 @@ struct DeviceGemmXdlSplitKCShuffle
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -449,7 +449,7 @@ struct DeviceGroupedGemmXdl
{
using Argument = DeviceGroupedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
......@@ -510,8 +510,8 @@ struct DeviceGroupedGemmXdl
true,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
......@@ -534,8 +534,8 @@ struct DeviceGroupedGemmXdl
false,
MaxGroupCount>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
......@@ -550,9 +550,10 @@ struct DeviceGroupedGemmXdl
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -204,7 +204,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
......@@ -241,8 +241,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(kernel,
nrepeat,
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -257,9 +257,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -211,7 +211,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
......@@ -253,8 +253,8 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
......@@ -272,9 +272,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
......
......@@ -182,7 +182,7 @@ struct DeviceReduceBlockWiseSecondCall
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
......@@ -224,8 +224,8 @@ struct DeviceReduceBlockWiseSecondCall
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
......@@ -243,10 +243,11 @@ struct DeviceReduceBlockWiseSecondCall
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -245,7 +245,7 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
......@@ -275,8 +275,6 @@ struct DeviceReduceMultiBlockAtomicAdd
float avg_time = 0;
KernelTimer timer;
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
InDataType,
......@@ -287,17 +285,8 @@ struct DeviceReduceMultiBlockAtomicAdd
InElementwiseOperation,
AccElementwiseOperation>;
printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
arg.gridSize,
BlockSize);
printf("Warm up\n");
for(int i = 0; i < nrepeat + 1; i++)
{
if(i == 1)
timer.Start();
launch_kernel(kernel_pre,
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
......@@ -305,7 +294,8 @@ struct DeviceReduceMultiBlockAtomicAdd
arg.out_dev_,
static_cast<OutDataType>(0.0f));
launch_kernel(kernel_main,
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
......@@ -318,19 +308,15 @@ struct DeviceReduceMultiBlockAtomicAdd
arg.alpha_,
arg.in_dev_,
arg.out_dev_);
};
timer.End();
avg_time = timer.GetElapsedTime() / nrepeat;
return (avg_time);
};
return avg_time;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -273,7 +273,7 @@ struct DeviceReduceMultiBlockPartialReduce
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
......@@ -313,8 +313,8 @@ struct DeviceReduceMultiBlockPartialReduce
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
......@@ -331,10 +331,11 @@ struct DeviceReduceMultiBlockPartialReduce
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -212,7 +212,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
......@@ -254,8 +254,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InElementwiseOperation,
OutElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
......@@ -272,10 +272,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
#pragma once
#include "common_header.hpp"
#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
namespace ck {
......@@ -248,4 +249,116 @@ struct GridwiseGemmPipeline_v1<2>
}
};
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
template <>
struct GridwiseGemmPipelineInterwave_v1<1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// block_sync_lds(); // moved into blockwise_gemm
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
{
};
template <index_t NumPrefetch, LoopScheduler LoopSched>
constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
}
}
} // namespace ck
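The selector above dispatches on the LoopScheduler enum at compile time via if constexpr, so the returned object's type fixes the pipeline implementation. A hypothetical usage sketch, assuming a remove_cvref_t helper is available in ck (otherwise std::remove_cvref_t from C++20 would do):

// Hypothetical: recover the pipeline type chosen by the selector.
using GridwiseGemmPipe = ck::remove_cvref_t<decltype(
    ck::GridwiseGemmPipeline_v1_Selector<1, ck::LoopScheduler::Interwave>())>;
// Yields GridwiseGemmPipelineInterwave_v1<1>; with LoopScheduler::Default it
// would be GridwiseGemmPipeline_v1<1>.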
......@@ -134,7 +134,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -473,8 +474,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
......@@ -483,7 +484,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{};
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......@@ -502,11 +504,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
......
......@@ -107,7 +107,8 @@ template <typename FloatAB,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -416,8 +417,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
......@@ -426,7 +427,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{};
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......@@ -445,11 +447,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
......
#pragma once
#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct DeviceConvFwdPtr_t
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
struct DeviceConvFwdPtrImpl;
std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
DeviceConvFwdPtr_t();
~DeviceConvFwdPtr_t();
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
const; // in, wei and out element ops are ignored for now since even if we change them, they
       // can't be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
};
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
#ifndef DEVICE_HPP
#define DEVICE_HPP
#pragma once
#include <memory>
#include <functional>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "stream_config.hpp"
#include "ck/options.hpp"
inline void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
{
std::ostringstream ss;
ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
<< "in function: " << __func__;
throw std::runtime_error(ss.str());
}
}
struct DeviceMem
{
......@@ -36,22 +49,17 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl;
};
using device_stream_t = hipStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
hipStream_t stream_id = nullptr;
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
template <typename... Args, typename F>
float launch_and_time_kernel(
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
float launch_and_time_kernel(const StreamConfig& stream_config,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
KernelTimer timer;
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
......@@ -61,24 +69,39 @@ float launch_and_time_kernel(
block_dim.y,
block_dim.z);
printf("Warm up\n");
const int nrepeat = 10;
hipStream_t stream_id = nullptr;
printf("Warm up 1 time\n");
// warm up
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
hipLaunchKernelGGL(
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
printf("Start running %d times...\n", nrepeat);
KernelTimer timer;
timer.Start();
for(int i = 0; i < nrepeat; ++i)
{
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
hipLaunchKernelGGL(
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
}
timer.End();
return timer.GetElapsedTime() / nrepeat;
}
}
else
{
hipLaunchKernelGGL(
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
return 0;
}
#else
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
return 0;
#endif
}
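With this change, timing is opt-in per call: when time_kernel_ is set, the kernel is warmed up once, launched nrepeat (= 10) times between timer start and end, and the average is returned; otherwise it is launched once, untimed, on the configured stream. A hypothetical call site, where the kernel and its arguments are placeholders:

// Hypothetical: time a launch on the default stream.
float avg_ms = launch_and_time_kernel(StreamConfig{nullptr, true}, // assumes {stream_id_, time_kernel_} order
                                      some_kernel,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0, // dynamic LDS bytes
                                      arg0,
                                      arg1);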
......@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -135,7 +135,8 @@ struct ReferenceCGemm : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -121,7 +121,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -291,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
#ifndef REFERENCE_CONV_FWD_HPP
#define REFERENCE_CONV_FWD_HPP
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "stream_config.hpp"
#include "device_base.hpp"
#include "host_tensor.hpp"
......@@ -251,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......@@ -311,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -124,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......