Commit ffa70551 authored by Jehandad Khan
Browse files

Fix formatting

parent 29e1829f
...@@ -435,7 +435,10 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -435,7 +435,10 @@ struct DeviceGemm_Xdl_CShuffle
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
#if 0 #if 0
{ {
...@@ -540,7 +543,10 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -540,7 +543,10 @@ struct DeviceGemm_Xdl_CShuffle
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -385,7 +385,10 @@ struct DeviceGemmXdlSplitK ...@@ -385,7 +385,10 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
...@@ -534,7 +537,10 @@ struct DeviceGemmXdlSplitK ...@@ -534,7 +537,10 @@ struct DeviceGemmXdlSplitK
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -391,7 +391,10 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -391,7 +391,10 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
...@@ -545,7 +548,10 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -545,7 +548,10 @@ struct DeviceGemmXdlSplitKCShuffle
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -366,7 +366,10 @@ struct DeviceGroupedGemmXdl ...@@ -366,7 +366,10 @@ struct DeviceGroupedGemmXdl
{ {
using Argument = DeviceGroupedGemmXdl::Argument; using Argument = DeviceGroupedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg; StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
...@@ -477,7 +480,10 @@ struct DeviceGroupedGemmXdl ...@@ -477,7 +480,10 @@ struct DeviceGroupedGemmXdl
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -204,7 +204,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -204,7 +204,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType, using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
...@@ -259,7 +262,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd ...@@ -259,7 +262,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.p_out_indices_dev_); arg.p_out_indices_dev_);
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
} }
......
...@@ -211,7 +211,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -211,7 +211,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
...@@ -274,7 +277,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -274,7 +277,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -182,7 +182,10 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -182,7 +182,10 @@ struct DeviceReduceBlockWiseSecondCall
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_); arg.inLengths_, arg.inStrides_);
...@@ -245,7 +248,10 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -245,7 +248,10 @@ struct DeviceReduceBlockWiseSecondCall
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -245,7 +245,8 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -245,7 +245,8 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool = false) float
Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
...@@ -329,7 +330,10 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -329,7 +330,10 @@ struct DeviceReduceMultiBlockAtomicAdd
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -273,7 +273,10 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -273,7 +273,10 @@ struct DeviceReduceMultiBlockPartialReduce
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
...@@ -333,7 +336,10 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -333,7 +336,10 @@ struct DeviceReduceMultiBlockPartialReduce
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -212,7 +212,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -212,7 +212,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) float Run(const Argument& arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false)
{ {
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
...@@ -274,7 +277,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -274,7 +277,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1, hipStream_t stream_id = nullptr, bool measure_time = false) override float Run(const BaseArgument* p_arg,
int nrepeat = 1,
hipStream_t stream_id = nullptr,
bool measure_time = false) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time); return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat, stream_id, measure_time);
}; };
......
...@@ -19,23 +19,36 @@ struct DeviceConvFwdPtr_t ...@@ -19,23 +19,36 @@ struct DeviceConvFwdPtr_t
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&); DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&); DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete; DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&)=delete; DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument> MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<BaseArgument>
size_t N, size_t K, size_t C, MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const; // in,wei and out element ops are ignored for now since even if we change them, they cant be linked std::vector<ck::index_t> input_right_pads)
std::unique_ptr<BaseInvoker> MakeInvokerPointer() const; // requires including BaseInvoker headers const; // in,wei and out element ops are ignored for now since even if we change them, they
// cant be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString(); std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr); bool IsSupportedArgument(const BaseArgument* arg_ptr);
}; };
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
...@@ -10,31 +10,30 @@ ...@@ -10,31 +10,30 @@
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
inline void hip_check(hipError_t x) inline void hip_check(hipError_t x)
{ {
if(x != hipSuccess) if(x != hipSuccess)
throw std::runtime_error("Failed to run HIP call"); throw std::runtime_error("Failed to run HIP call");
} }
template<typename F, F f> template <typename F, F f>
struct managed_deleter struct managed_deleter
{ {
template<typename T> template <typename T>
void operator()(T * t) void operator()(T* t)
{ {
if(t != nullptr) if(t != nullptr)
{ {
std::ignore = f(t); std::ignore = f(t);
} }
} }
}; };
template<typename T, typename F, F f> template <typename T, typename F, F f>
using managed_pointer = std::unique_ptr<T, managed_deleter<F, f>>; using managed_pointer = std::unique_ptr<T, managed_deleter<F, f>>;
using hipEventPtr = managed_pointer<typename std::remove_pointer<hipEvent_t>::type, decltype(&hipEventDestroy), hipEventDestroy>; using hipEventPtr = managed_pointer<typename std::remove_pointer<hipEvent_t>::type,
decltype(&hipEventDestroy),
hipEventDestroy>;
inline hipEventPtr make_hip_event() inline hipEventPtr make_hip_event()
{ {
...@@ -74,14 +73,25 @@ struct KernelTimer ...@@ -74,14 +73,25 @@ struct KernelTimer
using device_stream_t = hipStream_t; using device_stream_t = hipStream_t;
template <typename... Args, typename F> template <typename... Args, typename F>
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, hipStream_t stream_id, Args... args) void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{ {
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
} }
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_and_time_kernel( float launch_and_time_kernel(F kernel,
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, hipStream_t stream_id, bool measure_time, Args... args) int nrepeat,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
bool measure_time,
Args... args)
{ {
#if CK_TIME_KERNELS #if CK_TIME_KERNELS
KernelTimer timer; KernelTimer timer;
......
...@@ -29,8 +29,13 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -29,8 +29,13 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{ {
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
size_t N, size_t K, size_t C, MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> output_spatial_lengths,
...@@ -39,18 +44,29 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl ...@@ -39,18 +44,29 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const std::vector<ck::index_t> input_right_pads) const
{ {
return el->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides, return el->MakeArgumentPointer(in_ptr,
conv_filter_dilations, input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{}); wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
} }
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer() const std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer() const
{ {
return el->MakeInvokerPointer(); return el->MakeInvokerPointer();
} }
std::string GetTypeString() std::string GetTypeString() { return el->GetTypeString(); }
{
return el->GetTypeString();
}
bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg) bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
{ {
return el->IsSupportedArgument(arg); return el->IsSupportedArgument(arg);
...@@ -59,14 +75,23 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl ...@@ -59,14 +75,23 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el; ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
}; };
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr){} DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {}
// DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {} // DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) :
// pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {}
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default; DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default; DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other))){} DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
: pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr, std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
size_t N, size_t K, size_t C, DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> output_spatial_lengths,
...@@ -75,8 +100,19 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> DeviceConvFwdPtr_t::MakeArgume ...@@ -75,8 +100,19 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> DeviceConvFwdPtr_t::MakeArgume
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const std::vector<ck::index_t> input_right_pads) const
{ {
return pImpl->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides, return pImpl->MakeArgumentPointer(in_ptr,
conv_filter_dilations, input_left_pads, input_right_pads); wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
} }
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer() const std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer() const
...@@ -84,21 +120,21 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvoker ...@@ -84,21 +120,21 @@ std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvoker
return pImpl->MakeInvokerPointer(); return pImpl->MakeInvokerPointer();
} }
std::string DeviceConvFwdPtr_t::GetTypeString() std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); }
{
return pImpl->GetTypeString();
}
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr) bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{ {
return pImpl->IsSupportedArgument(arg_ptr); return pImpl->IsSupportedArgument(arg_ptr);
} }
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); instances.emplace_back(tmp);
...@@ -106,11 +142,14 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vec ...@@ -106,11 +142,14 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vec
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
...@@ -118,11 +157,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<Device ...@@ -118,11 +157,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<Device
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
...@@ -130,25 +172,29 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<Devic ...@@ -130,25 +172,29 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<Devic
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better instances.emplace_back(tmp); // Perhaps we can do better
} }
return; return;
} }
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances);
for(auto& kinder: local_instances) for(auto& kinder : local_instances)
{ {
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); instances.emplace_back(tmp);
......
...@@ -58,8 +58,7 @@ int main(int argc, char* argv[]) ...@@ -58,8 +58,7 @@ int main(int argc, char* argv[])
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
ck::app::profile_conv_fwd_impl( ck::app::profile_conv_fwd_impl(do_verification,
do_verification,
init_method, init_method,
do_log, do_log,
nrepeat, nrepeat,
......
...@@ -32,12 +32,9 @@ enum ConvOutputLayout ...@@ -32,12 +32,9 @@ enum ConvOutputLayout
void check_cuda_error(void) void check_cuda_error(void)
{ {
hipError_t err = hipGetLastError(); hipError_t err = hipGetLastError();
if (err != hipSuccess) if(err != hipSuccess)
{ {
std::cerr std::cerr << "Error: " << hipGetErrorString(err) << std::endl;
<< "Error: "
<< hipGetErrorString(err)
<< std::endl;
exit(err); exit(err);
} }
} }
...@@ -57,8 +54,6 @@ int getDriver(void) ...@@ -57,8 +54,6 @@ int getDriver(void)
return driver; return driver;
} }
namespace ck { namespace ck {
namespace app { namespace app {
struct DeviceMem struct DeviceMem
...@@ -132,7 +127,6 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -132,7 +127,6 @@ void profile_conv_fwd_impl(int do_verification,
app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz); app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
// data is already on device! // data is already on device!
// add device Conv instances // add device Conv instances
std::vector<DeviceConvFwdPtr_t> conv_ptrs; std::vector<DeviceConvFwdPtr_t> conv_ptrs;
if(data_type == F16_F16_F16) if(data_type == F16_F16_F16)
...@@ -161,7 +155,6 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -161,7 +155,6 @@ void profile_conv_fwd_impl(int do_verification,
hipSetDevice(deviceIndex); hipSetDevice(deviceIndex);
check_cuda_error(); check_cuda_error();
hipStream_t stream_id = nullptr; hipStream_t stream_id = nullptr;
hipStreamCreate(&stream_id); hipStreamCreate(&stream_id);
check_cuda_error(); check_cuda_error();
...@@ -169,8 +162,8 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -169,8 +162,8 @@ void profile_conv_fwd_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
for(auto& conv_ptr : conv_ptrs) for(auto& conv_ptr : conv_ptrs)
{ {
auto argument_ptr = conv_ptr.MakeArgumentPointer( auto argument_ptr =
static_cast<void*>(in_device_buf.GetDeviceBuffer()), conv_ptr.MakeArgumentPointer(static_cast<void*>(in_device_buf.GetDeviceBuffer()),
static_cast<void*>(wei_device_buf.GetDeviceBuffer()), static_cast<void*>(wei_device_buf.GetDeviceBuffer()),
static_cast<void*>(out_device_buf.GetDeviceBuffer()), static_cast<void*>(out_device_buf.GetDeviceBuffer()),
N, N,
...@@ -218,5 +211,5 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -218,5 +211,5 @@ void profile_conv_fwd_impl(int do_verification,
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
} }
} // namespace profiler } // namespace app
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment