Commit ffa70551 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Fix formatting

parent 29e1829f
...@@ -435,7 +435,10 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -435,7 +435,10 @@ struct DeviceGemm_Xdl_CShuffle
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
#if 0 #if 0
{ {
...@@ -483,24 +486,24 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -483,24 +486,24 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
ave_time = ave_time =
launch_and_time_kernel(kernel, launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
{ {
...@@ -516,31 +519,34 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -516,31 +519,34 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
ave_time = ave_time =
launch_and_time_kernel(kernel, launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -385,7 +385,10 @@ struct DeviceGemmXdlSplitK ...@@ -385,7 +385,10 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
...@@ -416,8 +419,8 @@ struct DeviceGemmXdlSplitK ...@@ -416,8 +419,8 @@ struct DeviceGemmXdlSplitK
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -442,7 +445,7 @@ struct DeviceGemmXdlSplitK ...@@ -442,7 +445,7 @@ struct DeviceGemmXdlSplitK
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -534,7 +537,10 @@ struct DeviceGemmXdlSplitK ...@@ -534,7 +537,10 @@ struct DeviceGemmXdlSplitK
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -391,7 +391,10 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -391,7 +391,10 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
...@@ -423,8 +426,8 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -423,8 +426,8 @@ struct DeviceGemmXdlSplitKCShuffle
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -449,7 +452,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -449,7 +452,7 @@ struct DeviceGemmXdlSplitKCShuffle
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -545,7 +548,10 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -545,7 +548,10 @@ struct DeviceGemmXdlSplitKCShuffle
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -366,7 +366,10 @@ struct DeviceGroupedGemmXdl ...@@ -366,7 +366,10 @@ struct DeviceGroupedGemmXdl
{ {
using Argument = DeviceGroupedGemmXdl::Argument; using Argument = DeviceGroupedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg; StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
...@@ -438,8 +441,8 @@ struct DeviceGroupedGemmXdl ...@@ -438,8 +441,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
gemm_desc_kernel_arg_arg, gemm_desc_kernel_arg_arg,
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
...@@ -464,8 +467,8 @@ struct DeviceGroupedGemmXdl ...@@ -464,8 +467,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
gemm_desc_kernel_arg_arg, gemm_desc_kernel_arg_arg,
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
...@@ -477,7 +480,10 @@ struct DeviceGroupedGemmXdl ...@@ -477,7 +480,10 @@ struct DeviceGroupedGemmXdl
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -204,7 +204,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -204,7 +204,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType, using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
...@@ -246,8 +249,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -246,8 +249,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
arg.a_grid_desc_m_k_, arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_, arg.b_grid_desc_m_,
arg.in_element_op_, arg.in_element_op_,
...@@ -259,7 +262,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -259,7 +262,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.p_out_indices_dev_); arg.p_out_indices_dev_);
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -211,7 +211,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -211,7 +211,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
...@@ -259,7 +262,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -259,7 +262,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
in_grid_desc_m_k, in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
arg.in_elementwise_op_, arg.in_elementwise_op_,
...@@ -274,7 +277,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -274,7 +277,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -182,7 +182,10 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -182,7 +182,10 @@ struct DeviceReduceBlockWiseSecondCall
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_); arg.inLengths_, arg.inStrides_);
...@@ -229,8 +232,8 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -229,8 +232,8 @@ struct DeviceReduceBlockWiseSecondCall
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
in_grid_desc_m_k, in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
arg.in_elementwise_op_, arg.in_elementwise_op_,
...@@ -245,7 +248,10 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -245,7 +248,10 @@ struct DeviceReduceBlockWiseSecondCall
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -245,7 +245,8 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -245,7 +245,8 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool = false) float
Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
...@@ -301,7 +302,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -301,7 +302,7 @@ struct DeviceReduceMultiBlockAtomicAdd
dim3(arg.gridSize_pre), dim3(arg.gridSize_pre),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
out_grid_desc_m, out_grid_desc_m,
arg.out_dev_, arg.out_dev_,
static_cast<OutDataType>(0.0f)); static_cast<OutDataType>(0.0f));
...@@ -310,7 +311,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -310,7 +311,7 @@ struct DeviceReduceMultiBlockAtomicAdd
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
in_grid_desc_m_k, in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
arg.in_elementwise_op_, arg.in_elementwise_op_,
...@@ -329,7 +330,10 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -329,7 +330,10 @@ struct DeviceReduceMultiBlockAtomicAdd
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -273,7 +273,10 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -273,7 +273,10 @@ struct DeviceReduceMultiBlockPartialReduce
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
...@@ -318,8 +321,8 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -318,8 +321,8 @@ struct DeviceReduceMultiBlockPartialReduce
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
in_grid_desc_m_k, in_grid_desc_m_k,
ws_desc_m_k, ws_desc_m_k,
arg.in_elementwise_op_, arg.in_elementwise_op_,
...@@ -333,7 +336,10 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -333,7 +336,10 @@ struct DeviceReduceMultiBlockPartialReduce
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -212,7 +212,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -212,7 +212,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
...@@ -259,8 +262,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -259,8 +262,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
stream_id, stream_id,
measure_time, measure_time,
in_grid_desc_m_k, in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
arg.in_elementwise_op_, arg.in_elementwise_op_,
...@@ -274,7 +277,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -274,7 +277,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -9,33 +9,46 @@ ...@@ -9,33 +9,46 @@
struct DeviceConvFwdPtr_t struct DeviceConvFwdPtr_t
{ {
using BaseArgument = ck::tensor_operation::device::BaseArgument; using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker; using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
struct DeviceConvFwdPtrImpl; struct DeviceConvFwdPtrImpl;
std::unique_ptr<DeviceConvFwdPtrImpl> pImpl; std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
DeviceConvFwdPtr_t(); DeviceConvFwdPtr_t();
~DeviceConvFwdPtr_t(); ~DeviceConvFwdPtr_t();
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&); DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&); DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete; DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&)=delete; DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument> MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<BaseArgument>
size_t N, size_t K, size_t C, MakeArgumentPointer(void* in_ptr,
std::vector<ck::index_t> input_spatial_lengths, void* wei_ptr,
std::vector<ck::index_t> filter_spatial_lengths, void* out_ptr,
std::vector<ck::index_t> output_spatial_lengths, size_t N,
std::vector<ck::index_t> conv_filter_strides, size_t K,
std::vector<ck::index_t> conv_filter_dilations, size_t C,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> input_right_pads) const; // in,wei and out element ops are ignored for now since even if we change them, they cant be linked std::vector<ck::index_t> filter_spatial_lengths,
std::unique_ptr<BaseInvoker> MakeInvokerPointer() const; // requires including BaseInvoker headers std::vector<ck::index_t> output_spatial_lengths,
std::string GetTypeString(); std::vector<ck::index_t> conv_filter_strides,
bool IsSupportedArgument(const BaseArgument* arg_ptr); std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
const; // in,wei and out element ops are ignored for now since even if we change them, they
// cant be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
}; };
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
...@@ -10,31 +10,30 @@ ...@@ -10,31 +10,30 @@
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
inline void hip_check(hipError_t x) inline void hip_check(hipError_t x)
{ {
if(x != hipSuccess) if(x != hipSuccess)
throw std::runtime_error("Failed to run HIP call"); throw std::runtime_error("Failed to run HIP call");
} }
template<typename F, F f> template <typename F, F f>
struct managed_deleter struct managed_deleter
{ {
template<typename T> template <typename T>
void operator()(T * t) void operator()(T* t)
{ {
if(t != nullptr) if(t != nullptr)
{ {
std::ignore = f(t); std::ignore = f(t);
} }
} }
}; };
template<typename T, typename F, F f> template <typename T, typename F, F f>
using managed_pointer = std::unique_ptr<T, managed_deleter<F, f>>; using managed_pointer = std::unique_ptr<T, managed_deleter<F, f>>;
using hipEventPtr = managed_pointer<typename std::remove_pointer<hipEvent_t>::type, decltype(&hipEventDestroy), hipEventDestroy>; using hipEventPtr = managed_pointer<typename std::remove_pointer<hipEvent_t>::type,
decltype(&hipEventDestroy),
hipEventDestroy>;
inline hipEventPtr make_hip_event() inline hipEventPtr make_hip_event()
{ {
...@@ -74,14 +73,25 @@ struct KernelTimer ...@@ -74,14 +73,25 @@ struct KernelTimer
using device_stream_t = hipStream_t; using device_stream_t = hipStream_t;
template <typename... Args, typename F> template <typename... Args, typename F>
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, hipStream_t stream_id, Args... args) void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{ {
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
} }
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_and_time_kernel( float launch_and_time_kernel(F kernel,
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, hipStream_t stream_id, bool measure_time, Args... args) int nrepeat,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
bool measure_time,
Args... args)
{ {
#if CK_TIME_KERNELS #if CK_TIME_KERNELS
KernelTimer timer; KernelTimer timer;
...@@ -113,14 +123,14 @@ float launch_and_time_kernel( ...@@ -113,14 +123,14 @@ float launch_and_time_kernel(
return timer.GetElapsedTime() / nrepeat; return timer.GetElapsedTime() / nrepeat;
#else #else
std::ignore = nrepeat; std::ignore = nrepeat;
hipEventPtr start = nullptr; hipEventPtr start = nullptr;
hipEventPtr stop = nullptr; hipEventPtr stop = nullptr;
float elapsed_time = 0.0f; float elapsed_time = 0.0f;
if(measure_time) if(measure_time)
{ {
start = make_hip_event(); start = make_hip_event();
stop = make_hip_event(); stop = make_hip_event();
hip_check(hipEventRecord(start.get(), stream_id)); hip_check(hipEventRecord(start.get(), stream_id));
} }
......
...@@ -29,28 +29,44 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -29,28 +29,44 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{ {
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
size_t N, size_t K, size_t C, MakeArgumentPointer(void* in_ptr,
std::vector<ck::index_t> input_spatial_lengths, void* wei_ptr,
std::vector<ck::index_t> filter_spatial_lengths, void* out_ptr,
std::vector<ck::index_t> output_spatial_lengths, size_t N,
std::vector<ck::index_t> conv_filter_strides, size_t K,
std::vector<ck::index_t> conv_filter_dilations, size_t C,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> input_right_pads) const std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const
{ {
return el->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides, return el->MakeArgumentPointer(in_ptr,
conv_filter_dilations, input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{}); wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
} }
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer() const std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer() const
{ {
return el->MakeInvokerPointer(); return el->MakeInvokerPointer();
} }
std::string GetTypeString() std::string GetTypeString() { return el->GetTypeString(); }
{
return el->GetTypeString();
}
bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg) bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
{ {
return el->IsSupportedArgument(arg); return el->IsSupportedArgument(arg);
...@@ -59,24 +75,44 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl ...@@ -59,24 +75,44 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el; ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
}; };
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr){} DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {}
// DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {} // DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) :
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default; // pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {}
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default; DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other))){} DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
: pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
size_t N, size_t K, size_t C, DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr,
std::vector<ck::index_t> input_spatial_lengths, void* wei_ptr,
std::vector<ck::index_t> filter_spatial_lengths, void* out_ptr,
std::vector<ck::index_t> output_spatial_lengths, size_t N,
std::vector<ck::index_t> conv_filter_strides, size_t K,
std::vector<ck::index_t> conv_filter_dilations, size_t C,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> input_right_pads) const std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const
{ {
return pImpl->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides, return pImpl->MakeArgumentPointer(in_ptr,
conv_filter_dilations, input_left_pads, input_right_pads); wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
} }
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer() const std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer() const
...@@ -84,21 +120,21 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvoker ...@@ -84,21 +120,21 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvoker
return pImpl->MakeInvokerPointer(); return pImpl->MakeInvokerPointer();
} }
std::string DeviceConvFwdPtr_t::GetTypeString() std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); }
{
return pImpl->GetTypeString();
}
/// Asks the wrapped device instance whether it can execute the given argument
/// (problem sizes / layouts); callers must check this before Run.
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
    return pImpl->IsSupportedArgument(arg_ptr);
}
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); instances.emplace_back(tmp);
...@@ -106,11 +142,14 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vec ...@@ -106,11 +142,14 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vec
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
...@@ -118,11 +157,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<Device ...@@ -118,11 +157,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<Device
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
...@@ -130,25 +172,29 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<Devic ...@@ -130,25 +172,29 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<Devic
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
} }
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); instances.emplace_back(tmp);
......
...@@ -26,14 +26,14 @@ int main(int argc, char* argv[]) ...@@ -26,14 +26,14 @@ int main(int argc, char* argv[])
exit(1); exit(1);
} }
const ConvDataType data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const ConvDataType data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
const int nrepeat = std::stoi(argv[9]); const int nrepeat = std::stoi(argv[9]);
const ck::index_t N = std::stoi(argv[10]); const ck::index_t N = std::stoi(argv[10]);
const ck::index_t K = std::stoi(argv[11]); const ck::index_t K = std::stoi(argv[11]);
...@@ -58,21 +58,20 @@ int main(int argc, char* argv[]) ...@@ -58,21 +58,20 @@ int main(int argc, char* argv[])
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
ck::app::profile_conv_fwd_impl( ck::app::profile_conv_fwd_impl(do_verification,
do_verification, init_method,
init_method, do_log,
do_log, nrepeat,
nrepeat, data_type,
data_type, N,
N, K,
K, C,
C, std::vector<ck::index_t>{Hi, Wi},
std::vector<ck::index_t>{Hi, Wi}, std::vector<ck::index_t>{Y, X},
std::vector<ck::index_t>{Y, X}, std::vector<ck::index_t>{Ho, Wo},
std::vector<ck::index_t>{Ho, Wo}, std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
std::vector<ck::index_t>{conv_stride_h, conv_stride_w}, std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w}, std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w}, std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
return 1; return 1;
} }
...@@ -32,13 +32,10 @@ enum ConvOutputLayout ...@@ -32,13 +32,10 @@ enum ConvOutputLayout
void check_cuda_error(void) void check_cuda_error(void)
{ {
hipError_t err = hipGetLastError(); hipError_t err = hipGetLastError();
if (err != hipSuccess) if(err != hipSuccess)
{ {
std::cerr std::cerr << "Error: " << hipGetErrorString(err) << std::endl;
<< "Error: " exit(err);
<< hipGetErrorString(err)
<< std::endl;
exit(err);
} }
} }
std::string getDeviceName(int device) std::string getDeviceName(int device)
...@@ -57,8 +54,6 @@ int getDriver(void) ...@@ -57,8 +54,6 @@ int getDriver(void)
return driver; return driver;
} }
namespace ck { namespace ck {
namespace app { namespace app {
struct DeviceMem struct DeviceMem
...@@ -119,12 +114,12 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -119,12 +114,12 @@ void profile_conv_fwd_impl(int do_verification,
const ck::index_t Ho = output_spatial_lengths[0]; const ck::index_t Ho = output_spatial_lengths[0];
const ck::index_t Wo = output_spatial_lengths[1]; const ck::index_t Wo = output_spatial_lengths[1];
const auto in_sz = N * C * Hi * Wi; const auto in_sz = N * C * Hi * Wi;
const auto wei_sz = K * C * Y * X; const auto wei_sz = K * C * Y * X;
const auto out_sz = N * K * Ho * Wo; const auto out_sz = N * K * Ho * Wo;
using WeiDataType = float; using WeiDataType = float;
using InDataType = float; using InDataType = float;
using OutDataType = float; using OutDataType = float;
app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz); app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
...@@ -132,7 +127,6 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -132,7 +127,6 @@ void profile_conv_fwd_impl(int do_verification,
app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz); app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
// data is already on device! // data is already on device!
// add device Conv instances // add device Conv instances
std::vector<DeviceConvFwdPtr_t> conv_ptrs; std::vector<DeviceConvFwdPtr_t> conv_ptrs;
if(data_type == F16_F16_F16) if(data_type == F16_F16_F16)
...@@ -157,11 +151,10 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -157,11 +151,10 @@ void profile_conv_fwd_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
int deviceIndex = 0; int deviceIndex = 0;
hipSetDevice(deviceIndex); hipSetDevice(deviceIndex);
check_cuda_error(); check_cuda_error();
hipStream_t stream_id = nullptr; hipStream_t stream_id = nullptr;
hipStreamCreate(&stream_id); hipStreamCreate(&stream_id);
check_cuda_error(); check_cuda_error();
...@@ -169,27 +162,27 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -169,27 +162,27 @@ void profile_conv_fwd_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
for(auto& conv_ptr : conv_ptrs) for(auto& conv_ptr : conv_ptrs)
{ {
auto argument_ptr = conv_ptr.MakeArgumentPointer( auto argument_ptr =
static_cast<void*>(in_device_buf.GetDeviceBuffer()), conv_ptr.MakeArgumentPointer(static_cast<void*>(in_device_buf.GetDeviceBuffer()),
static_cast<void*>(wei_device_buf.GetDeviceBuffer()), static_cast<void*>(wei_device_buf.GetDeviceBuffer()),
static_cast<void*>(out_device_buf.GetDeviceBuffer()), static_cast<void*>(out_device_buf.GetDeviceBuffer()),
N, N,
K, K,
C, C,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
auto invoker_ptr = conv_ptr.MakeInvokerPointer(); auto invoker_ptr = conv_ptr.MakeInvokerPointer();
if(conv_ptr.IsSupportedArgument(argument_ptr.get())) if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
{ {
std::string conv_name = conv_ptr.GetTypeString(); std::string conv_name = conv_ptr.GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat, stream_id, true); float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat, stream_id, true);
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
...@@ -218,5 +211,5 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -218,5 +211,5 @@ void profile_conv_fwd_impl(int do_verification,
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
} }
} // namespace profiler } // namespace app
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment