Commit b134b7d6 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 090ba885 9f71ff48
...
@@ -747,7 +747,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<
                     GridwiseGemm,
@@ -796,8 +795,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     true>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -827,8 +826,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     false>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -847,9 +846,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -862,17 +862,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static bool IsSupportedArgument(const Argument& arg)
     {
         // Input tensors can't be bigger than 2GB each.
-        constexpr std::size_t GB2 = 2 * 1e9;
+        constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);

-        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
-        {
-            return false;
-        }
-        if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
-        {
-            return false;
-        }
-        if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
+        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
+           arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
+           arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
         {
             return false;
         }
...
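Note on the guard above: GetElementSpaceSize() returns an element count, so the old 2 * 1e9 comparison let an fp32 tensor of 10^9 elements (4 GB) slip through. The new check scales by the element size and compares in a 64-bit type. A hedged standalone sketch of that logic; the function and parameter names are illustrative, not CK API:

#include <cstddef>
#include <cstdint>

using long_index_t = std::int64_t; // mirrors ck::long_index_t

bool fits_in_2gb(long_index_t num_elements, std::size_t bytes_per_element)
{
    constexpr long_index_t GB2 = long_index_t{1} << 31; // 2 GiB in bytes

    // Multiply in 64-bit so a large element count cannot overflow before the compare.
    return num_elements * static_cast<long_index_t>(bytes_per_element) <= GB2;
}

// e.g. 10^9 fp32 elements passes an element-count check against 2e9,
// but fails the byte check: 4e9 bytes > 2^31.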
...
@@ -9,8 +9,7 @@ namespace device {
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename D0ReduceOperation,
-          typename D1ReduceOperation>
+          typename D1ElementwiseOperation>
 struct DeviceGemmReduce : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
-                                                              D0ReduceOperation d0_reduce_op,
-                                                              D1ReduceOperation d1_reduce_op,
+                                                              D1ElementwiseOperation d1_element_op,
                                                               ck::index_t BatchCount = 1) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename D0ReduceOperation,
-          typename D1ReduceOperation>
+          typename D1ElementwiseOperation>
 using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
                                                              BElementwiseOperation,
                                                              CElementwiseOperation,
-                                                             D0ReduceOperation,
-                                                             D1ReduceOperation>>;
+                                                             D1ElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
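Across this commit every invoker trades `int nrepeat` for a `StreamConfig`. A caller-side sketch of the new contract, assuming StreamConfig is an aggregate bundling a HIP stream and a "time this launch" flag (the member layout is an assumption read off the diff, not verified API):

#include <hip/hip_runtime.h>

float run_op(ck::tensor_operation::device::BaseInvoker& invoker,
             const ck::tensor_operation::device::BaseArgument* p_arg,
             hipStream_t stream,
             bool time_kernel)
{
    // One Run() signature now serves both the old nrepeat > 0 (timed) and
    // nrepeat == 0 (plain launch) call sites; the policy rides in StreamConfig.
    return invoker.Run(p_arg, StreamConfig{stream, time_kernel});
}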
...
@@ -14,6 +14,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
+// non-c-shuffle version currently has compiler issues with register spill, which in turn causes
+// validation failures.
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -29,6 +32,7 @@ template <typename ALayout,
           typename CElementwiseOperation,
           typename D0ReduceOperation,
           typename D1ReduceOperation,
+          typename D1ElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -61,12 +65,12 @@ template <typename ALayout,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
+          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                                BElementwiseOperation,
                                                                CElementwiseOperation,
-                                                               D0ReduceOperation,
-                                                               D1ReduceOperation>
+                                                               D1ElementwiseOperation>
 {
     using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -382,6 +386,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         CElementwiseOperation,
         D0ReduceOperation,
         D1ReduceOperation,
+        D1ElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_AK0_M_AK1,
@@ -421,7 +426,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         CReduceThreadClusterLengths_MPerBlock_NPerBlock,
         CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
+        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+        LoopSched>;

     // Argument
     struct Argument : public BaseArgument
@@ -440,8 +446,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 D0ReduceOperation d0_reduce_op,
-                 D1ReduceOperation d1_reduce_op)
+                 D1ElementwiseOperation d1_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
@@ -457,8 +462,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op},
-              d0_reduce_op_{d0_reduce_op},
-              d1_reduce_op_{d1_reduce_op}
+              d1_element_op_{d1_element_op}
         {
             if(GridwiseGemm::CheckValidity(
                    a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -491,8 +495,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        D0ReduceOperation d0_reduce_op_;
-        D1ReduceOperation d1_reduce_op_;
+        D1ElementwiseOperation d1_element_op_;
     };

     // Invoker
@@ -500,7 +503,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int /* nrepeat */ = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -530,11 +533,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-            if(has_main_k0_block_loop)
+            const auto K =
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            float elapsed_time = 0.0f;
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
                     GridwiseGemm,
@@ -544,8 +547,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    D0ReduceOperation,
-                    D1ReduceOperation,
+                    D1ElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -553,25 +555,26 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;

-                launch_kernel(kernel,
-                              dim3(grid_size),
-                              dim3(BlockSize),
-                              0,
-                              arg.p_a_grid_,
-                              arg.p_b_grid_,
-                              arg.p_c_grid_,
-                              arg.p_d0_grid_,
-                              arg.p_d1_grid_,
-                              arg.a_element_op_,
-                              arg.b_element_op_,
-                              arg.c_element_op_,
-                              arg.d0_reduce_op_,
-                              arg.d1_reduce_op_,
-                              arg.a_grid_desc_ak0_m_ak1_,
-                              arg.b_grid_desc_bk0_n_bk1_,
-                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                              arg.d_grid_desc_mblock_mperblock_,
-                              arg.block_2_ctile_map_);
+                elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.p_d0_grid_,
+                                           arg.p_d1_grid_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.d1_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.d_grid_desc_mblock_mperblock_,
+                                           arg.block_2_ctile_map_);
             }
             else
             {
@@ -583,8 +586,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    D0ReduceOperation,
-                    D1ReduceOperation,
+                    D1ElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -592,34 +594,36 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;

-                launch_kernel(kernel,
-                              dim3(grid_size),
-                              dim3(BlockSize),
-                              0,
-                              arg.p_a_grid_,
-                              arg.p_b_grid_,
-                              arg.p_c_grid_,
-                              arg.p_d0_grid_,
-                              arg.p_d1_grid_,
-                              arg.a_element_op_,
-                              arg.b_element_op_,
-                              arg.c_element_op_,
-                              arg.d0_reduce_op_,
-                              arg.d1_reduce_op_,
-                              arg.a_grid_desc_ak0_m_ak1_,
-                              arg.b_grid_desc_bk0_n_bk1_,
-                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                              arg.d_grid_desc_mblock_mperblock_,
-                              arg.block_2_ctile_map_);
+                elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.p_d0_grid_,
+                                           arg.p_d1_grid_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.d1_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.d_grid_desc_mblock_mperblock_,
+                                           arg.block_2_ctile_map_);
             }

-            return 0;
+            return elapsed_time;
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -655,8 +659,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,
-                             D0ReduceOperation d0_reduce_op,
-                             D1ReduceOperation d1_reduce_op)
+                             D1ElementwiseOperation d1_element_op)
     {
         return Argument{p_a,
                         p_b,
@@ -672,8 +675,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        d0_reduce_op,
-                        d1_reduce_op};
+                        d1_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -693,8 +695,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,
-                             D0ReduceOperation d0_reduce_op,
-                             D1ReduceOperation d1_reduce_op,
+                             D1ElementwiseOperation d1_element_op,
                              index_t /* KBatch */ = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -711,8 +712,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          d0_reduce_op,
-                                          d1_reduce_op);
+                                          d1_element_op);
     }

     // polymorphic
...
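The two reduce-op template slots collapse into one elementwise hook here: the kernel now receives a d1_element_op that transforms each C value before the D1 accumulation. A hedged sketch of the kind of functor that slot plausibly expects, squaring each value so a sum-of-squares can be accumulated next to the plain D0 sum; the operator() shape is an assumption inferred from how d1_element_op_ is forwarded above, not CK's definition:

// Illustrative D1-style elementwise functor (assumed signature).
struct UnarySquareSketch
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        y = x * x; // transform before the D1 reduction accumulates it
    }
};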
...
@@ -290,8 +290,9 @@ struct DeviceGemmXdl
     {
         using Argument = DeviceGemmXdl::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+#if 0
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                           << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -304,6 +305,7 @@ struct DeviceGemmXdl
                 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                           << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
             }
+#endif

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                             arg.b_grid_desc_k0_n_k1_,
@@ -317,13 +319,12 @@ struct DeviceGemmXdl
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<
                     GridwiseGemm,
@@ -338,8 +339,8 @@ struct DeviceGemmXdl
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     true>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -369,8 +370,8 @@ struct DeviceGemmXdl
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     false>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -390,9 +391,10 @@ struct DeviceGemmXdl
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
        {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
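The recurring pattern in these hunks: the main-loop test now receives the full K (= K0 * K1, read off the [K0, M, K1] descriptor) instead of K0 alone, so the predicate no longer has to know K1 implicitly. A hedged sketch of just that call-site computation, using only names visible in the diff (the threshold logic inside CalculateHasMainKBlockLoop is not shown here and is left to the gridwise GEMM):

// Illustrative only: recover full K from a [K0, M, K1] descriptor and ask the
// gridwise GEMM whether a main K block loop is needed.
template <typename GridwiseGemm, typename AGridDescK0MK1>
bool needs_main_k_loop(const AGridDescK0MK1& a_desc)
{
    const auto K = a_desc.GetLength(ck::Number<0>{}) * // K0
                   a_desc.GetLength(ck::Number<2>{});  // K1

    return GridwiseGemm::CalculateHasMainKBlockLoop(K);
}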
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v3r1.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumPrefetch = 1>
struct DeviceGemmXdl_C_Shuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % AK1 == 0);
const index_t K0 = K / AK1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_k0_m_k1;
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % BK1 == 0);
const index_t K0 = K / BK1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_k0_n_k1;
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}
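MakeAGridDescriptor_K0_M_K1 above unmerges the K dimension into (K0, AK1), so index k of the original M x K matrix corresponds to k = k0 * AK1 + k1 in the transformed [K0, M, K1] view. A self-contained worked check of that index algebra for row-major A with M = 2, K = 8, AK1 = 2 (a sketch, not CK code):

#include <cassert>

int main()
{
    const int M = 2, K = 8, AK1 = 2;
    const int K0 = K / AK1; // = 4, as in the descriptor helper

    int a[2][8]; // row-major A, StrideA = K
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            a[m][k] = 10 * m + k; // tag each element with its (m, k) coordinates

    // Reading through the unmerged [K0, M, K1] view must hit a[m][k0 * AK1 + k1].
    for(int k0 = 0; k0 < K0; ++k0)
        for(int m = 0; m < M; ++m)
            for(int k1 = 0; k1 < AK1; ++k1)
            {
                const int k = k0 * AK1 + k1; // the unmerge: k <-> (k0, k1)
                assert(a[m][k] == 10 * m + k);
            }
    return 0;
}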
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl,
NumPrefetch>;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdl_C_Shuffle::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_v3r1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_xdlops_v3r1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdl_C_Shuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
@@ -1,5 +1,4 @@
-#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
-#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
 #include "device.hpp"
@@ -266,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
     {
         using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -291,18 +289,17 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
                                             arg.N01_))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
+                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting");
             }

             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r2<
                     GridwiseGemm,
@@ -323,8 +320,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
                     true>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -362,8 +359,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
                     false>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -385,9 +382,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -505,4 +503,3 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
...
@@ -273,7 +273,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -303,13 +303,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r2<
                     GridwiseGemm,
@@ -330,8 +329,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
                     true>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -369,8 +368,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
                     false>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -392,9 +391,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
...
@@ -312,7 +312,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -345,13 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r3<
                     GridwiseGemm,
@@ -375,8 +374,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
                     true>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -419,8 +418,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
                     false>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -444,9 +443,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
...
@@ -14,6 +14,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
+// non-c-shuffle version currently has compiler issues with register spill, which in turn causes
+// validation failures.
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -54,7 +57,8 @@ template <typename ALayout,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemm_Xdl_CShuffle
     : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
@@ -375,7 +379,8 @@ struct DeviceGemm_Xdl_CShuffle
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CShuffleBlockTransferScalarPerVector_NPerBlock>;
+        CShuffleBlockTransferScalarPerVector_NPerBlock,
+        LoopSched>;

     // Argument
     struct Argument : public BaseArgument
@@ -435,7 +440,7 @@ struct DeviceGemm_Xdl_CShuffle
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -462,13 +467,12 @@ struct DeviceGemm_Xdl_CShuffle
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdl_cshuffle_v1<
                     GridwiseGemm,
@@ -483,42 +487,22 @@ struct DeviceGemm_Xdl_CShuffle
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;

-                if(nrepeat == 0)
-                {
-                    launch_kernel(kernel,
-                                  dim3(grid_size),
-                                  dim3(BlockSize),
-                                  0,
-                                  arg.p_a_grid_,
-                                  arg.p_b_grid_,
-                                  arg.p_c_grid_,
-                                  arg.a_element_op_,
-                                  arg.b_element_op_,
-                                  arg.c_element_op_,
-                                  arg.a_grid_desc_ak0_m_ak1_,
-                                  arg.b_grid_desc_bk0_n_bk1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.block_2_ctile_map_);
-                }
-                else
-                {
-                    ave_time =
-                        launch_and_time_kernel(kernel,
-                                               nrepeat,
-                                               dim3(grid_size),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.p_a_grid_,
-                                               arg.p_b_grid_,
-                                               arg.p_c_grid_,
-                                               arg.a_element_op_,
-                                               arg.b_element_op_,
-                                               arg.c_element_op_,
-                                               arg.a_grid_desc_ak0_m_ak1_,
-                                               arg.b_grid_desc_bk0_n_bk1_,
-                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                               arg.block_2_ctile_map_);
-                }
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.block_2_ctile_map_);
             }
             else
             {
@@ -534,52 +518,32 @@ struct DeviceGemm_Xdl_CShuffle
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;

-                if(nrepeat == 0)
-                {
-                    launch_kernel(kernel,
-                                  dim3(grid_size),
-                                  dim3(BlockSize),
-                                  0,
-                                  arg.p_a_grid_,
-                                  arg.p_b_grid_,
-                                  arg.p_c_grid_,
-                                  arg.a_element_op_,
-                                  arg.b_element_op_,
-                                  arg.c_element_op_,
-                                  arg.a_grid_desc_ak0_m_ak1_,
-                                  arg.b_grid_desc_bk0_n_bk1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.block_2_ctile_map_);
-                }
-                else
-                {
-                    ave_time =
-                        launch_and_time_kernel(kernel,
-                                               nrepeat,
-                                               dim3(grid_size),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.p_a_grid_,
-                                               arg.p_b_grid_,
-                                               arg.p_c_grid_,
-                                               arg.a_element_op_,
-                                               arg.b_element_op_,
-                                               arg.c_element_op_,
-                                               arg.a_grid_desc_ak0_m_ak1_,
-                                               arg.b_grid_desc_bk0_n_bk1_,
-                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                               arg.block_2_ctile_map_);
-                }
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.block_2_ctile_map_);
             }

             return ave_time;
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
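The deleted nrepeat == 0 branches all collapse into a single launch_and_time_kernel call, which implies the launch-vs-time decision now lives inside the wrapper. A hedged sketch of what such a wrapper could look like; the StreamConfig field names and the body are assumptions, not CK's implementation:

#include <hip/hip_runtime.h>
#include <cstddef>

// Assumed shape: when timing is off this degenerates to a plain launch
// returning 0, which is why the explicit nrepeat == 0 paths could be deleted.
template <typename Kernel, typename... Args>
float launch_and_time_kernel_sketch(const StreamConfig& cfg,
                                    Kernel kernel,
                                    dim3 grid,
                                    dim3 block,
                                    std::size_t lds_bytes,
                                    Args... args)
{
    if(!cfg.time_kernel_) // assumed field
    {
        kernel<<<grid, block, lds_bytes, cfg.stream_id_>>>(args...);
        return 0.0f;
    }

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, cfg.stream_id_);
    kernel<<<grid, block, lds_bytes, cfg.stream_id_>>>(args...);
    hipEventRecord(stop, cfg.stream_id_);
    hipEventSynchronize(stop);

    float ms = 0.0f;
    hipEventElapsedTime(&ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}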
...
@@ -385,8 +385,11 @@ struct DeviceGemmXdlSplitK
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            ShowInfo(arg);
+
             const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
@@ -408,50 +411,30 @@ struct DeviceGemmXdlSplitK
             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
-                if(nrepeat > 0)
-                {
-                    ShowInfo(arg);
-                    ave_time = launch_and_time_kernel(kernel,
-                                                      nrepeat,
-                                                      dim3(grid_size),
-                                                      dim3(BlockSize),
-                                                      0,
-                                                      arg.p_a_grid_,
-                                                      arg.p_b_grid_,
-                                                      arg.p_c_grid_,
-                                                      arg.a_grid_desc_kbatch_k0_m_k1_,
-                                                      arg.b_grid_desc_kbatch_k0_n_k1_,
-                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                      arg.a_element_op_,
-                                                      arg.b_element_op_,
-                                                      arg.c_element_op_,
-                                                      arg.block_2_ctile_map_);
-                }
-
-                if(kbatch > 1 || nrepeat <= 0)
-                {
-                    hipGetErrorString(
-                        hipMemset(arg.p_c_grid_,
-                                  0,
-                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
-                                      sizeof(CDataType)));
-                    launch_kernel(kernel,
-                                  dim3(grid_size),
-                                  dim3(BlockSize),
-                                  0,
-                                  arg.p_a_grid_,
-                                  arg.p_b_grid_,
-                                  arg.p_c_grid_,
-                                  arg.a_grid_desc_kbatch_k0_m_k1_,
-                                  arg.b_grid_desc_kbatch_k0_n_k1_,
-                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                  arg.a_element_op_,
-                                  arg.b_element_op_,
-                                  arg.c_element_op_,
-                                  arg.block_2_ctile_map_);
-                }
+                // FIXME: this should be moved outside of DeviceOp
+                hipGetErrorString(
+                    hipMemset(arg.p_c_grid_,
+                              0,
+                              arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
+                                  sizeof(CDataType)));
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_kbatch_k0_m_k1_,
+                                                  arg.b_grid_desc_kbatch_k0_n_k1_,
+                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
             };

             if(has_main_k0_block_loop)
             {
                 if(kbatch == 1)
@@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
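Why the hipMemset became unconditional (and picked up a FIXME): a split-K run with kbatch > 1 accumulates partial products into C via AtomicAdd, so C must be zero before every launch or a previous run's results get added in; the old code skipped the memset on timed single-run paths. A minimal host-side sketch of that precondition, using a hypothetical helper name (hipMemsetAsync is real HIP API; the wrapper is not CK's):

#include <hip/hip_runtime.h>
#include <cstddef>

// Zero C on the compute stream so the memset is ordered before the split-K
// kernel's AtomicAdd accumulation. c_element_space is an element count.
inline hipError_t zero_c_for_splitk(void* p_c,
                                    std::size_t c_element_space,
                                    std::size_t bytes_per_element,
                                    hipStream_t stream)
{
    return hipMemsetAsync(p_c, 0, c_element_space * bytes_per_element, stream);
}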
...
@@ -391,8 +391,11 @@ struct DeviceGemmXdlSplitKCShuffle
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            ShowInfo(arg);
+
             const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
@@ -414,51 +417,29 @@ struct DeviceGemmXdlSplitKCShuffle
             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
-                if(nrepeat > 0)
-                {
-                    ShowInfo(arg);
-                    ave_time =
-                        launch_and_time_kernel(kernel,
-                                               nrepeat,
-                                               dim3(grid_size),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.p_a_grid_,
-                                               arg.p_b_grid_,
-                                               arg.p_c_grid_,
-                                               arg.a_grid_desc_kbatch_k0_m_k1_,
-                                               arg.b_grid_desc_kbatch_k0_n_k1_,
-                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                               arg.a_element_op_,
-                                               arg.b_element_op_,
-                                               arg.c_element_op_,
-                                               arg.block_2_ctile_map_);
-                }
-
-                if(kbatch > 1 || nrepeat <= 0)
-                {
-                    hipGetErrorString(hipMemset(
-                        arg.p_c_grid_,
-                        0,
-                        arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-                            sizeof(CDataType)));
-                    launch_kernel(kernel,
-                                  dim3(grid_size),
-                                  dim3(BlockSize),
-                                  0,
-                                  arg.p_a_grid_,
-                                  arg.p_b_grid_,
-                                  arg.p_c_grid_,
-                                  arg.a_grid_desc_kbatch_k0_m_k1_,
-                                  arg.b_grid_desc_kbatch_k0_n_k1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.a_element_op_,
-                                  arg.b_element_op_,
-                                  arg.c_element_op_,
-                                  arg.block_2_ctile_map_);
-                }
+                hipGetErrorString(hipMemset(
+                    arg.p_c_grid_,
+                    0,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+                        sizeof(CDataType)));
+
+                launch_and_time_kernel(stream_config,
+                                       kernel,
+                                       dim3(grid_size),
+                                       dim3(BlockSize),
+                                       0,
+                                       arg.p_a_grid_,
+                                       arg.p_b_grid_,
+                                       arg.p_c_grid_,
+                                       arg.a_grid_desc_kbatch_k0_m_k1_,
+                                       arg.b_grid_desc_kbatch_k0_n_k1_,
+                                       arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                       arg.a_element_op_,
+                                       arg.b_element_op_,
+                                       arg.c_element_op_,
+                                       arg.block_2_ctile_map_);
             };

             if(has_main_k0_block_loop)
             {
                 if(kbatch == 1)
@@ -542,9 +523,10 @@ struct DeviceGemmXdlSplitKCShuffle
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
...
...
@@ -17,6 +17,88 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename GemmDesc,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          bool HasMainKBlockLoop,
+          index_t MaxGroupCount>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_grouped_gemm_xdlops_v2r3(
+            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
+            const index_t group_count,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    const index_t block_id = get_block_1d_id();
+
+#if 1
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
+           i < group_count)
+        {
+            auto group_id = i;
+
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                gemm_descs[group_id].a_ptr,
+                gemm_descs[group_id].b_ptr,
+                gemm_descs[group_id].c_ptr,
+                p_shared,
+                gemm_descs[group_id].a_grid_desc_k0_m_k1_,
+                gemm_descs[group_id].b_grid_desc_k0_n_k1_,
+                gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                a_element_op,
+                b_element_op,
+                c_element_op,
+                gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
+        }
+    });
+#else
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
+
+    index_t group_id = 0;
+
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
+                    i < group_count)
+                       ? i
+                       : group_id;
+    });
+
+    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        gemm_desc_ptr[group_id].a_ptr,
+        gemm_desc_ptr[group_id].b_ptr,
+        gemm_desc_ptr[group_id].c_ptr,
+        p_shared,
+        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
+        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
+        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+        a_element_op,
+        b_element_op,
+        c_element_op,
+        gemm_desc_ptr[group_id].block_2_ctile_map_,
+        block_id_grp);
+#endif
+#else
+    ignore = gemm_descs;
+    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
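The grouped kernel above resolves its group by scanning the statically indexed descriptor array with static_for: each group owns the contiguous block range [BlockStart_, BlockEnd_), and a workgroup finds its group by interval membership. A host-side analogue of that search, as an illustrative sketch (GroupRange and find_group are hypothetical, not CK names):

#include <vector>

struct GroupRange
{
    int BlockStart_;
    int BlockEnd_; // one past the last block of this group
};

int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    for(int g = 0; g < static_cast<int>(groups.size()); ++g)
    {
        if(block_id >= groups[g].BlockStart_ && block_id < groups[g].BlockEnd_)
            return g; // same predicate the kernel evaluates per static_for index
    }
    return -1; // block not owned by any group
}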
...@@ -290,17 +372,18 @@ struct DeviceGroupedGemmXdl ...@@ -290,17 +372,18 @@ struct DeviceGroupedGemmXdl
{ {
grid_size_ = 0; grid_size_ = 0;
group_count_ = static_cast<int>(gemm_shapes.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
if(!(group_count_ == p_a.size() && group_count_ == p_b.size() && if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
group_count_ == p_c.size())) group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
{ {
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size"); throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
} }
gemm_desc_kernel_arg_.reserve(group_count_); gemm_desc_kernel_arg_.reserve(group_count_);
for(index_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_shapes.size(); i++)
{ {
const index_t M = gemm_shapes[i].M; const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N; const index_t N = gemm_shapes[i].N;
@@ -366,61 +449,55 @@ struct DeviceGroupedGemmXdl
{
using Argument = DeviceGroupedGemmXdl::Argument;

- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
- StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
+ StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;

- bool has_main_k0_block_loop = true;
+ bool has_main_k_block_loop = true;

static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_desc_kernel_arg_.size())
{
- gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i];
+ gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];

- std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
-           << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
-           << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-           << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
+ std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
+           << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+           << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+           << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";

- std::cout << ", arg.b_grid_desc_k0_n_k1_{"
-           << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
-           << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-           << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
+ std::cout << ", arg.b_grid_desc_k0_n_k1_{"
+           << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+           << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+           << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";

- std::cout << ", arg.c_grid_desc_m_n_{ "
-           << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
-           << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+ std::cout << ", arg.c_grid_desc_m_n_{ "
+           << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
+           << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;

- if(!GridwiseGemm::CheckValidity(
-     gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_,
-     gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_,
-     gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_,
-     arg.M01_,
-     arg.N01_))
+ if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
+                                 gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
+                                 gemm_desc_kernel_args[i].c_grid_desc_m_n_,
+                                 arg.M01_,
+                                 arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}

- const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
+ const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
+                gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);

- if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
+ if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
- throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
+ throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
});

float ave_time = 0;

- if(has_main_k0_block_loop)
+ if(has_main_k_block_loop)
{
const auto kernel =
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
@@ -433,12 +510,12 @@ struct DeviceGroupedGemmXdl
true,
MaxGroupCount>;

- ave_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ ave_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
- gemm_desc_kernel_arg_arg,
+ gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,

@@ -457,12 +534,12 @@ struct DeviceGroupedGemmXdl
false,
MaxGroupCount>;

- ave_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ ave_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
- gemm_desc_kernel_arg_arg,
+ gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,

@@ -473,9 +550,10 @@ struct DeviceGroupedGemmXdl
}

// polymorphic
- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

@@ -487,7 +565,7 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
- if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
+ if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
return false;
else
return true;

...
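The validity loop above also enforces that every group agrees on `has_main_k_block_loop`: the main-loop branch is baked into the fused kernel as a compile-time template parameter, so one instantiation must serve all groups. Note the check now feeds the total reduction length `K = K0 * K1` (lengths `I0` and `I2` of the `K0, M, K1` descriptor) to `CalculateHasMainKBlockLoop` instead of `K0` alone. A minimal host-side sketch of the consistency check, where the `KPerBlock` threshold is an assumed stand-in for whatever rule `CalculateHasMainKBlockLoop` actually encodes:

#include <stdexcept>
#include <vector>

// Assumed stand-in: the real rule lives in GridwiseGemm::CalculateHasMainKBlockLoop.
bool has_main_k_block_loop(long K, long KPerBlock) { return K > KPerBlock; }

void check_groups_agree(const std::vector<long>& Ks, long KPerBlock)
{
    const bool expected = has_main_k_block_loop(Ks.front(), KPerBlock);
    for(long K : Ks)
    {
        if(has_main_k_block_loop(K, KPerBlock) != expected)
            throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
    }
}

int main()
{
    check_groups_agree({256, 512, 1024}, /*KPerBlock=*/32); // all agree: no throw
}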
@@ -204,7 +204,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,

@@ -241,8 +241,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);

- return launch_and_time_kernel(kernel,
-                               nrepeat,
+ return launch_and_time_kernel(stream_config,
+                               kernel,
dim3(grid_size),
dim3(BlockSize),
0,

@@ -257,9 +257,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.p_out_indices_dev_);
}

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

...
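The pooling problem is flattened into a `ReduceM x ReduceK` reduction and `ReduceM` is tiled across workgroups, which is where `grid_size = ReduceM / ReduceM_BlockTileSize` comes from. A worked example with made-up dimensions (all shape values below are hypothetical):

#include <cassert>

int main()
{
    const int ReduceM               = 1 * 112 * 112 * 64; // e.g. N * Ho * Wo * C = 802816
    const int ReduceM_BlockTileSize = 256;                // reduction rows covered per block

    assert(ReduceM % ReduceM_BlockTileSize == 0); // the plain division assumes this
    const int grid_size = ReduceM / ReduceM_BlockTileSize;
    assert(grid_size == 3136);
}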
@@ -211,7 +211,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);

@@ -253,8 +253,8 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
InElementwiseOperation,
AccElementwiseOperation>;

- avg_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ avg_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,

@@ -272,9 +272,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return (avg_time);
};

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};

...
@@ -182,7 +182,7 @@ struct DeviceReduceBlockWiseSecondCall
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);

@@ -224,8 +224,8 @@ struct DeviceReduceBlockWiseSecondCall
InElementwiseOperation,
AccElementwiseOperation>;

- avg_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ avg_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,

@@ -243,10 +243,11 @@ struct DeviceReduceBlockWiseSecondCall
return (avg_time);
};

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
- };
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+ }
};

bool IsSupportedArgument(const BaseArgument* p_arg) override

...
@@ -245,7 +245,7 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);

@@ -275,8 +275,6 @@ struct DeviceReduceMultiBlockAtomicAdd
float avg_time = 0;

- KernelTimer timer;
-
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
InDataType,

@@ -287,50 +285,38 @@ struct DeviceReduceMultiBlockAtomicAdd
InElementwiseOperation,
AccElementwiseOperation>;

- printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
-        arg.gridSize,
-        BlockSize);
- printf("Warm up\n");
-
- for(int i = 0; i < nrepeat + 1; i++)
- {
-     if(i == 1)
-         timer.Start();
-
-     launch_kernel(kernel_pre,
-                   dim3(arg.gridSize_pre),
-                   dim3(BlockSize),
-                   0,
-                   out_grid_desc_m,
-                   arg.out_dev_,
-                   static_cast<OutDataType>(0.0f));
-
-     launch_kernel(kernel_main,
-                   dim3(arg.gridSize),
-                   dim3(BlockSize),
-                   0,
-                   in_grid_desc_m_k,
-                   out_grid_desc_m,
-                   arg.in_elementwise_op_,
-                   arg.acc_elementwise_op_,
-                   arg.blkGroupSize,
-                   arg.kBlockTileIterations,
-                   arg.alpha_,
-                   arg.in_dev_,
-                   arg.out_dev_);
- };
-
- timer.End();
-
- avg_time = timer.GetElapsedTime() / nrepeat;
-
- return (avg_time);
- };
+ avg_time += launch_and_time_kernel(stream_config,
+                                    kernel_pre,
+                                    dim3(arg.gridSize_pre),
+                                    dim3(BlockSize),
+                                    0,
+                                    out_grid_desc_m,
+                                    arg.out_dev_,
+                                    static_cast<OutDataType>(0.0f));
+
+ avg_time += launch_and_time_kernel(stream_config,
+                                    kernel_main,
+                                    dim3(arg.gridSize),
+                                    dim3(BlockSize),
+                                    0,
+                                    in_grid_desc_m_k,
+                                    out_grid_desc_m,
+                                    arg.in_elementwise_op_,
+                                    arg.acc_elementwise_op_,
+                                    arg.blkGroupSize,
+                                    arg.kBlockTileIterations,
+                                    arg.alpha_,
+                                    arg.in_dev_,
+                                    arg.out_dev_);
+
+ return avg_time;
+ }

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
- };
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+ }
};

bool IsSupportedArgument(const BaseArgument* p_arg) override

...
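The hand-rolled warm-up/`KernelTimer` loop deleted above is replaced by two `launch_and_time_kernel` calls whose averages are summed, so the reported time covers both the buffer-zeroing pass and the atomic-add reduction. As a hedged sketch of what such a helper typically centralizes — warm-up, event-based timing, averaging — using only the standard HIP event API (this is not the library's implementation, and error handling is omitted):

#include <hip/hip_runtime.h>

template <typename Kernel, typename... Args>
float launch_and_time(hipStream_t stream, int nrepeat, Kernel kernel,
                      dim3 grid, dim3 block, Args... args)
{
    // one warm-up launch, excluded from the measurement
    hipLaunchKernelGGL(kernel, grid, block, 0, stream, args...);

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(kernel, grid, block, 0, stream, args...);
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);

    float total_ms = 0.0f;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat; // average time per launch, in ms
}

One caveat the sketch makes visible: repeating an atomic-add kernel for timing keeps accumulating into `out_dev_`, so the result buffer is only valid after a fresh zeroing pass — a reason to keep the zeroing kernel and the main kernel as separately timed launches.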
@@ -273,7 +273,7 @@ struct DeviceReduceMultiBlockPartialReduce
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);

@@ -313,8 +313,8 @@ struct DeviceReduceMultiBlockPartialReduce
InElementwiseOperation,
AccElementwiseOperation>;

- avg_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ avg_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,

@@ -331,10 +331,11 @@ struct DeviceReduceMultiBlockPartialReduce
return (avg_time);
};

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
- };
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+ }
};

bool IsSupportedArgument(const BaseArgument* p_arg) override

...
@@ -212,7 +212,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);

@@ -254,8 +254,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InElementwiseOperation,
OutElementwiseOperation>;

- avg_time = launch_and_time_kernel(kernel,
-                                   nrepeat,
+ avg_time = launch_and_time_kernel(stream_config,
+                                   kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,

@@ -272,10 +272,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (avg_time);
};

- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+           const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
- };
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+ }
};

bool IsSupportedArgument(const BaseArgument* p_arg) override

...
@@ -143,37 +143,6 @@ struct AddHardswishAdd
}
};

- struct RequantReluRequant
- {
-     // FIXME: We just need one scale for Relu / Leaky Relu / PRelu
-     RequantReluRequant(float scaleGemm, float scaleRelu)
-         : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
-     {
-     }
-
-     __host__ __device__ constexpr void operator()(int8_t& y, const int& x) const
-     {
-         float gemm_requant = scaleGemm_ * static_cast<float>(x);
-         float relu         = gemm_requant > 0 ? gemm_requant : 0;
-         float relu_requant = scaleRelu_ * relu;
-         y = static_cast<int8_t>(relu_requant > 127 ? 127
-                                 : relu_requant < -128 ? -128 : relu_requant);
-     }
-
-     // for reference_gemm
-     __host__ __device__ constexpr void operator()(float& y, const float& x) const
-     {
-         float gemm_requant = scaleGemm_ * x;
-         float relu         = gemm_requant > 0 ? gemm_requant : 0;
-         float relu_requant = scaleRelu_ * relu;
-         y = static_cast<float>(relu_requant > 127 ? 127
-                                : relu_requant < -128 ? -128 : relu_requant);
-     }
-
-     float scaleGemm_;
-     float scaleRelu_;
- };

// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

...
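For reference, the arithmetic of the removed `RequantReluRequant` operator — rescale, ReLU, rescale, saturating cast to `int8_t` — can be written compactly with `std::clamp`. A host-only sketch reproducing the deleted logic:

#include <algorithm>
#include <cstdint>

// Same math as the removed operator()(int8_t&, const int&), with std::clamp
// replacing the nested ternaries for the int8 saturation.
inline std::int8_t requant_relu_requant(int x, float scale_gemm, float scale_relu)
{
    const float gemm_requant = scale_gemm * static_cast<float>(x);
    const float relu         = std::max(gemm_requant, 0.0f); // ReLU
    const float relu_requant = scale_relu * relu;
    return static_cast<std::int8_t>(std::clamp(relu_requant, -128.0f, 127.0f));
}

After the ReLU the value is non-negative, so the lower clamp bound never fires; it is kept only to mirror the removed code.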
@@ -5,20 +5,6 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {

- struct ReduceSum
- {
-     __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
-
-     __host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
- };
-
- struct ReduceSquareSum
- {
-     __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
-
-     __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
- };

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
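The removed `ReduceSum`/`ReduceSquareSum` pair followed a small protocol: a neutral starting value (`GetReduceZeroValue`) plus an in-place accumulate (`Reduce`). The host-only sketch below reuses the deleted definitions (minus the `__host__ __device__` qualifiers) to show how such an op drives a reduction; the generic `reduce` harness is illustrative, not library code:

#include <iostream>

struct ReduceSum
{
    static constexpr float GetReduceZeroValue() { return 0.0f; }
    void Reduce(float& acc, float v) const { acc += v; }
};

struct ReduceSquareSum
{
    static constexpr float GetReduceZeroValue() { return 0.0f; }
    void Reduce(float& acc, float v) const { acc += v * v; }
};

template <typename Op>
float reduce(const float* p, int n, Op op)
{
    float acc = Op::GetReduceZeroValue(); // start from the op's neutral value
    for(int i = 0; i < n; ++i)
        op.Reduce(acc, p[i]); // fold each element into the accumulator
    return acc;
}

int main()
{
    const float v[] = {1.0f, 2.0f, 3.0f};
    std::cout << reduce(v, 3, ReduceSum{}) << '\n';       // 6
    std::cout << reduce(v, 3, ReduceSquareSum{}) << '\n'; // 14
}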