Commit e00a943e authored by myamlak's avatar myamlak
Browse files

Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

parents ffe12e2e 9f71ff48
#ifndef DEVICE_BASE_HPP #pragma once
#define DEVICE_BASE_HPP
#include <string> #include <string>
#include "stream_config.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -22,7 +23,10 @@ struct BaseInvoker ...@@ -22,7 +23,10 @@ struct BaseInvoker
BaseInvoker(const BaseInvoker&) = default; BaseInvoker(const BaseInvoker&) = default;
BaseInvoker& operator=(const BaseInvoker&) = default; BaseInvoker& operator=(const BaseInvoker&) = default;
virtual float Run(const BaseArgument*, int = 1) = 0; virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
{
return float{0};
}
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
...@@ -33,8 +37,8 @@ struct BaseOperator ...@@ -33,8 +37,8 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default; BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default;
virtual bool IsSupportedArgument(const BaseArgument*) = 0; virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const = 0; virtual std::string GetTypeString() const { return ""; }
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
...@@ -42,4 +46,3 @@ struct BaseOperator ...@@ -42,4 +46,3 @@ struct BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -106,6 +106,9 @@ __global__ void ...@@ -106,6 +106,9 @@ __global__ void
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
} }
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -154,7 +157,8 @@ template <typename ALayout, ...@@ -154,7 +157,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -600,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -600,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>; CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1)); using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
...@@ -688,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -688,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -724,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -724,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
...@@ -743,7 +749,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -743,7 +749,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -783,7 +791,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -783,7 +791,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -805,13 +815,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -805,13 +815,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
return 0; return elapsed_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl ...@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl
{ {
using Argument = DeviceBatchedGemmXdl::Argument; using Argument = DeviceBatchedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl ...@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl ...@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl ...@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -55,7 +55,8 @@ template <typename ALayout, ...@@ -55,7 +55,8 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceCGemm_4Gemm_Xdl_CShuffle struct DeviceCGemm_4Gemm_Xdl_CShuffle
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -376,7 +377,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -376,7 +377,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -448,7 +450,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -448,7 +450,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
...@@ -478,77 +480,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -478,77 +480,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_c_grid_real_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_real = c_real - aux needed here!!!
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_c_grid_imag_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!!
}
else
{
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -564,8 +498,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -564,8 +498,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -580,11 +514,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -580,11 +514,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// // c_real = c_real - aux needed here!!! // c_real = c_real - aux needed here!!!
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -600,8 +534,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -600,8 +534,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -618,7 +552,6 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -618,7 +552,6 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_imag = c_imag + aux needed here!!! // c_imag = c_imag + aux needed here!!!
} }
}
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdl_cshuffle_v1<
...@@ -634,77 +567,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -634,77 +567,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_c_grid_real_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// // c_real = c_real - aux needed here!!!
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_c_grid_imag_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!!
}
else
{
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -720,8 +585,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -720,8 +585,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -736,11 +601,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -736,11 +601,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_real = c_real - aux needed here!!! // // c_real = c_real - aux needed here!!!
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -756,8 +621,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -756,8 +621,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
ave_time += ave_time +=
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -774,15 +639,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -774,15 +639,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_imag = c_imag + aux needed here!!! // c_imag = c_imag + aux needed here!!!
} }
}
return ave_time; return ave_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
ShowInfo(arg); ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -437,35 +438,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -437,35 +438,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float ave_time = 0; float ave_time = 0;
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset( hipGetErrorString(hipMemset(
arg.p_c_grid_, arg.p_c_grid_,
0, 0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); sizeof(CDataType)));
launch_kernel(kernel, launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -479,7 +459,6 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -479,7 +459,6 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
}
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
...@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
...@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
true>; true>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
false>; false>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -642,7 +642,7 @@ struct ...@@ -642,7 +642,7 @@ struct
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -727,8 +727,8 @@ struct ...@@ -727,8 +727,8 @@ struct
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -771,8 +771,8 @@ struct ...@@ -771,8 +771,8 @@ struct
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -795,9 +795,10 @@ struct ...@@ -795,9 +795,10 @@ struct
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "conv_fwd_util.hpp" #include "conv_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_conv_fwd.hpp" #include "device_conv_fwd.hpp"
#include "common_header.hpp" #include "common_header.hpp"
...@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto naive_conv3d_fwd = const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType, ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
...@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation>; OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd, float ave_time = launch_and_time_kernel(stream_config,
nrepeat, naive_conv3d_fwd,
dim3(256), dim3(256),
dim3(256), dim3(256),
0, 0,
...@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
...@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -1241,7 +1241,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1241,7 +1241,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
...@@ -1316,8 +1316,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1316,8 +1316,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
true>; true>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -1349,8 +1349,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1349,8 +1349,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
false>; false>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -1369,9 +1369,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1369,9 +1369,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -747,7 +747,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -747,7 +747,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -795,8 +795,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -795,8 +795,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -826,8 +826,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -826,8 +826,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -846,9 +846,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -846,9 +846,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -14,6 +14,9 @@ namespace ck { ...@@ -14,6 +14,9 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -62,7 +65,8 @@ template <typename ALayout, ...@@ -62,7 +65,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -422,7 +426,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -422,7 +426,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>; CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -498,7 +503,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -498,7 +503,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -531,6 +536,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -531,6 +536,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
...@@ -549,7 +555,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -549,7 +555,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -586,7 +594,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -586,7 +594,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -606,13 +616,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -606,13 +616,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
return 0; return elapsed_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -290,7 +290,7 @@ struct DeviceGemmXdl ...@@ -290,7 +290,7 @@ struct DeviceGemmXdl
{ {
using Argument = DeviceGemmXdl::Argument; using Argument = DeviceGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -339,8 +339,8 @@ struct DeviceGemmXdl ...@@ -339,8 +339,8 @@ struct DeviceGemmXdl
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -370,8 +370,8 @@ struct DeviceGemmXdl ...@@ -370,8 +370,8 @@ struct DeviceGemmXdl
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -391,9 +391,10 @@ struct DeviceGemmXdl ...@@ -391,9 +391,10 @@ struct DeviceGemmXdl
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -264,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -264,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
{ {
using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument; using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -320,8 +320,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -320,8 +320,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -359,8 +359,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -359,8 +359,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -382,9 +382,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -382,9 +382,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -273,7 +273,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -273,7 +273,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -329,8 +329,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -329,8 +329,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -368,8 +368,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -368,8 +368,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -391,9 +391,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -391,9 +391,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -312,7 +312,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -312,7 +312,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -374,8 +374,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -374,8 +374,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -418,8 +418,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -418,8 +418,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -443,9 +443,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -443,9 +443,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -14,6 +14,9 @@ namespace ck { ...@@ -14,6 +14,9 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -54,7 +57,8 @@ template <typename ALayout, ...@@ -54,7 +57,8 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemm_Xdl_CShuffle struct DeviceGemm_Xdl_CShuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -375,7 +379,8 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -375,7 +379,8 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -435,7 +440,7 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -435,7 +440,7 @@ struct DeviceGemm_Xdl_CShuffle
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -482,28 +487,9 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -482,28 +487,9 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
ave_time = ave_time =
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -518,7 +504,6 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -518,7 +504,6 @@ struct DeviceGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
}
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdl_cshuffle_v1<
...@@ -533,29 +518,9 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -533,29 +518,9 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
ave_time = ave_time =
launch_and_time_kernel(kernel, launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -570,15 +535,15 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -570,15 +535,15 @@ struct DeviceGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
}
return ave_time; return ave_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment