Commit c3738ce3 authored by Bartlomiej Kocot

Fixes

parent eb898ad6
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_operations)
@@ -9,7 +9,7 @@
 #include <vector>
 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleaddx2_relu.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -50,11 +50,11 @@ struct SimpleDeviceMem
     void* p_mem_;
 };
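The hunk above shows only the tail of SimpleDeviceMem. In CK's client examples this helper is conventionally a thin RAII wrapper over hipMalloc/hipFree; the sketch below assumes that conventional shape and is not the verbatim code from this commit:

// A minimal SimpleDeviceMem sketch, assuming the conventional
// hipMalloc/hipFree RAII wrapper used throughout CK's client examples.
#include <cstddef>
#include <hip/hip_runtime.h>

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;
    explicit SimpleDeviceMem(std::size_t mem_size) : p_mem_{nullptr}
    {
        (void)hipMalloc(&p_mem_, mem_size); // allocate raw device memory
    }
    void* GetDeviceBuffer() const { return p_mem_; }
    ~SimpleDeviceMem() { (void)hipFree(p_mem_); } // free when the buffer leaves scope
    void* p_mem_;
};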
-int main()
+int execute_conv_fwd_scaleadd_scaleadd_relu()
 {
-    // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space
-    // However, CK's API only accept length and stride with order of GNCDHW/GKCZYX/GNKDHW
-    // Hence, we need to adjust the order of stride
+    // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space.
+    // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW.
+    // Hence, we need to adjust the order of strides.
     std::array<ck::index_t, 6> in_lengths{G, N, C, Di, Hi, Wi};
     std::array<ck::index_t, 6> in_strides{
         C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
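Those stride values follow directly from the packed NDHWGC layout: a dimension's stride is the product of the extents of all faster-varying dimensions (C varies fastest), and the array above simply lists those strides in the GNCDHW order the API expects. A small self-contained check of that arithmetic (illustrative only, not CK code):

// Illustrative stride check (not CK code): compute packed strides for an
// NDHWGC memory layout, then reorder them to GNCDHW as CK's API expects.
#include <array>
#include <cstdio>

int main()
{
    const long N = 2, Di = 4, Hi = 8, Wi = 8, G = 3, C = 16;
    (void)N; // N's own extent never enters any stride

    // Packed layout: stride(dim) = product of extents of faster-varying dims.
    const long stride_C  = 1;
    const long stride_G  = C;
    const long stride_Wi = G * C;
    const long stride_Hi = Wi * G * C;
    const long stride_Di = Hi * Wi * G * C;
    const long stride_N  = Di * Hi * Wi * G * C;

    // GNCDHW order, matching in_strides above.
    const std::array<long, 6> in_strides{
        stride_G, stride_N, stride_C, stride_Di, stride_Hi, stride_Wi};

    for(long s : in_strides)
        std::printf("%ld ", s); // prints: 16 12288 1 3072 384 48
    std::printf("\n");
    return 0;
}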
@@ -208,4 +208,5 @@ int main()
         std::cout << "Done" << std::endl;
     }
+    return 0;
 }
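The body elided between the two hunks above is not part of this diff; CK client examples conventionally ask the instance factory for every matching device op, skip the unsupported ones, and time the rest. A hypothetical, self-contained sketch of that selection loop follows; BaseOp, get_instances, and friends are stand-ins, not CK's API:

// Hypothetical stand-ins (not CK API) sketching the instance-selection loop
// that CK client examples conventionally run between these two hunks.
#include <cstdio>
#include <limits>
#include <memory>
#include <string>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp() = default;
    virtual bool IsSupported() const = 0;  // cf. op_ptr->IsSupportedArgument(...)
    virtual float RunMs() const = 0;       // cf. invoker.Run(...), returns milliseconds
    virtual std::string Name() const = 0;  // cf. op_ptr->GetTypeString()
};

struct OpA : BaseOp
{
    bool IsSupported() const override { return false; } // pretend this one cannot run
    float RunMs() const override { return 0.0f; }
    std::string Name() const override { return "instance_a"; }
};

struct OpB : BaseOp
{
    bool IsSupported() const override { return true; }
    float RunMs() const override { return 1.25f; }
    std::string Name() const override { return "instance_b"; }
};

std::vector<std::unique_ptr<BaseOp>> get_instances()
{
    std::vector<std::unique_ptr<BaseOp>> ops;
    ops.emplace_back(new OpA{});
    ops.emplace_back(new OpB{});
    return ops;
}

int main()
{
    float best = std::numeric_limits<float>::max();
    std::string best_name;
    for(const auto& op : get_instances())
    {
        if(!op->IsSupported())
            continue; // skip instances that cannot handle this problem
        const float ms = op->RunMs();
        if(ms < best) { best = ms; best_name = op->Name(); }
    }
    std::printf("best: %s (%f ms)\n", best_name.c_str(), best);
    std::printf("Done\n");
    return 0;
}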
@@ -8,8 +8,10 @@
 using InDataType = ck::bhalf_t;
 using WeiDataType = ck::bhalf_t;
 using OutDataType = ck::bhalf_t;
-// Use std tuple instead ck tuple to avoid clang
+// Use std tuple instead of ck tuple to avoid clang
 // implicit instantiation of undefined template error.
 using DDataTypes = std::tuple<ck::bhalf_t, ck::bhalf_t>;
-#include "grouped_conv_fwd_scaleaddx2_relu.inc"
+#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
+int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
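All four per-precision sources follow the same translation-unit pattern: define the type aliases first, then textually include the shared .inc body that consumes them. A minimal single-file illustration of that pattern (the section marked below would live in the shared .inc; names other than the aliases are illustrative):

// Single-file illustration of the "aliases first, shared .inc second" pattern.
#include <cstddef>
#include <iostream>
#include <tuple>

// --- per-precision prelude: one variant of these aliases per .cpp file ----
using InDataType  = float;
using OutDataType = float;
// std::tuple rather than ck::Tuple, per the workaround comment above.
using DDataTypes  = std::tuple<float, float>;

// --- this part would be pulled in via
// --- #include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
int execute_conv_fwd_scaleadd_scaleadd_relu()
{
    constexpr std::size_t num_d = std::tuple_size<DDataTypes>::value;
    std::cout << "configured with " << num_d << " D (post-op) tensors\n";
    return 0;
}

// --- per-precision entry point, as in the files above ----------------------
int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }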
@@ -8,8 +8,10 @@
 using InDataType = ck::half_t;
 using WeiDataType = ck::half_t;
 using OutDataType = ck::half_t;
-// Use std tuple instead ck tuple to avoid clang
+// Use std tuple instead of ck tuple to avoid clang
 // implicit instantiation of undefined template error.
 using DDataTypes = std::tuple<ck::half_t, ck::half_t>;
-#include "grouped_conv_fwd_scaleaddx2_relu.inc"
+#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
+int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
@@ -8,8 +8,10 @@
 using InDataType = float;
 using WeiDataType = float;
 using OutDataType = float;
-// Use std tuple instead ck tuple to avoid clang
+// Use std tuple instead of ck tuple to avoid clang
 // implicit instantiation of undefined template error.
 using DDataTypes = std::tuple<float, float>;
-#include "grouped_conv_fwd_scaleaddx2_relu.inc"
+#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
+int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
@@ -8,8 +8,10 @@
 using InDataType = int8_t;
 using WeiDataType = int8_t;
 using OutDataType = int8_t;
-// Use std tuple instead ck tuple to avoid clang
+// Use std tuple instead of ck tuple to avoid clang
 // implicit instantiation of undefined template error.
 using DDataTypes = std::tuple<float, float>;
-#include "grouped_conv_fwd_scaleaddx2_relu.inc"
+#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
+int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_fp32 grouped_conv_fwd_scaleaddx2_relu_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_fp32 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_fp16 grouped_conv_fwd_scaleaddx2_relu_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_bf16 grouped_conv_fwd_scaleaddx2_relu_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_bf16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_int8 grouped_conv_fwd_scaleaddx2_relu_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_int8 PRIVATE composable_kernel::device_operations)
@@ -31,8 +31,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
         add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
         # ScaleAdd ScaleAdd Relu
-        add_example_executable(example_convnd_fwd_xdl_scaleaddx2_relu_fp16 convnd_fwd_xdl_scaleaddx2_relu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleaddx2_relu_fp16)
+        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
         set(target 1)
     endif()
 endforeach()
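The ScaleAdd ScaleAdd Relu epilogue exercised by these targets fuses two extra D tensors into the convolution output. The operator's definition is not shown in this commit; reading the name compositionally, one plausible scalar reference is relu(scale1 * (scale0 * conv + d0) + d1). The sketch below encodes that assumed formula, so treat it as a guess, not CK's confirmed semantics:

// Assumed scalar reference for a ScaleAdd -> ScaleAdd -> Relu epilogue. The
// exact formula of ck::tensor_operation::element_wise::ScaleAddScaleAddRelu
// is not shown in this commit; this reading of the name is an assumption.
#include <algorithm>
#include <cstdio>

struct ScaleAddScaleAddReluRef
{
    float scale0;
    float scale1;

    float operator()(float conv_out, float d0, float d1) const
    {
        const float t = scale1 * (scale0 * conv_out + d0) + d1; // two scale-add stages
        return std::max(t, 0.0f);                               // relu
    }
};

int main()
{
    const ScaleAddScaleAddReluRef op{0.5f, 2.0f};
    std::printf("%f\n", op(3.0f, 1.0f, -4.0f)); // 2*(0.5*3+1)-4 = 1 -> relu -> 1
    return 0;
}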
@@ -190,9 +190,8 @@ bool run_grouped_conv_fwd(bool do_verification,
     if(!conv.IsSupportedArgument(argument))
     {
-        throw std::runtime_error(
-            "wrong! device_conv with the specified compilation parameters does "
-            "not support this Conv problem");
+        throw std::runtime_error("The device op with the specified compilation parameters does "
+                                 "not support this convolution problem.");
     }
     float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...
@@ -97,7 +97,7 @@ using DeviceGroupedConvNDFwdInstance =
 using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
 namespace {
-// Use own implementation to pass two more tensors for post op
+// Use a custom implementation to pass two more tensors for the post-op.
 template <ck::index_t NDimSpatial,
           typename InDataType,
           typename WeiDataType,
@@ -181,7 +181,7 @@ bool run_grouped_conv_fwd(bool do_verification,
     copy(conv_param.input_right_pads_, input_right_pads);
     const std::array<const void*, NumDs> ds = {d0_buf.GetDeviceBuffer(), d1_buf.GetDeviceBuffer()};
-    // do Conv
     auto conv     = DeviceConvNDFwdInstance{};
     auto invoker  = conv.MakeInvoker();
     auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
@@ -208,9 +208,8 @@ bool run_grouped_conv_fwd(bool do_verification,
     if(!conv.IsSupportedArgument(argument))
     {
-        throw std::runtime_error(
-            "wrong! device_conv with the specified compilation parameters does "
-            "not support this Conv problem");
+        throw std::runtime_error("The device op with the specified compilation parameters does "
+                                 "not support this convolution problem.");
     }
     float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...
@@ -59,11 +59,11 @@ struct ReferenceConvFwd : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,
-                 const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors)
+                 const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors)
             : input_{input},
               weight_{weight},
               output_{output},
-              postop_tensors_{postop_tensors},
+              d_tensors_{d_tensors},
               conv_strides_{conv_filter_strides},
               conv_dilations_{conv_filter_dilations},
               in_left_pads_{input_left_pads},
@@ -78,7 +78,7 @@ struct ReferenceConvFwd : public device::BaseOperator
         const Tensor<WeiDataType>& weight_;
         Tensor<OutDataType>& output_;
-        const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors_;
+        const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors_;
         std::vector<index_t> conv_strides_;
         std::vector<index_t> conv_dilations_;
@@ -141,19 +141,18 @@ struct ReferenceConvFwd : public device::BaseOperator
                 }
                 else if constexpr(NumDTensor == 1)
                 {
-                    arg.out_element_op_(
-                        v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, wo));
+                    arg.out_element_op_(v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, wo));
                 }
                 else if constexpr(NumDTensor == 2)
                 {
                     arg.out_element_op_(v_out,
                                         v_acc_converted,
-                                        arg.postop_tensors_[0](g, n, k, wo),
-                                        arg.postop_tensors_[1](g, n, k, wo));
+                                        arg.d_tensors_[0](g, n, k, wo),
+                                        arg.d_tensors_[1](g, n, k, wo));
                 }
                 else
                 {
-                    throw std::runtime_error("ElementOp not supported in reference.");
+                    throw std::runtime_error("Output ElementOp not supported in reference.");
                 }
                 arg.output_(g, n, k, wo) = v_out;
             };
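The reference implementation above selects the out_element_op_ call arity at compile time with an if-constexpr chain over NumDTensor, so only the branch matching the actual number of D tensors is ever instantiated. A compact stand-alone sketch of that dispatch technique (apply_post_op and the functor are illustrative, not CK code):

// Stand-alone sketch of compile-time arity dispatch on the number of D
// tensors, mirroring the if-constexpr chain above (names are illustrative).
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t NumD, typename Op>
float apply_post_op(Op op, float acc, const std::array<float, NumD>& ds)
{
    float out = 0.0f;
    if constexpr(NumD == 0) { op(out, acc); }              // discarded unless NumD == 0
    else if constexpr(NumD == 1) { op(out, acc, ds[0]); }
    else if constexpr(NumD == 2) { op(out, acc, ds[0], ds[1]); }
    else { static_assert(NumD <= 2, "unsupported number of D tensors"); }
    return out;
}

int main()
{
    // A two-D-tensor functor: out = relu(acc + d0 + d1).
    auto op = [](float& out, float acc, float d0, float d1) {
        const float t = acc + d0 + d1;
        out = t > 0.0f ? t : 0.0f;
    };
    std::printf("%f\n", apply_post_op<2>(op, 1.0f, {2.0f, -0.5f})); // prints 2.5
    return 0;
}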
@@ -216,18 +215,18 @@ struct ReferenceConvFwd : public device::BaseOperator
                 else if constexpr(NumDTensor == 1)
                 {
                     arg.out_element_op_(
-                        v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, ho, wo));
+                        v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, ho, wo));
                 }
                 else if constexpr(NumDTensor == 2)
                 {
                     arg.out_element_op_(v_out,
                                         v_acc_converted,
-                                        arg.postop_tensors_[0](g, n, k, ho, wo),
-                                        arg.postop_tensors_[1](g, n, k, ho, wo));
+                                        arg.d_tensors_[0](g, n, k, ho, wo),
+                                        arg.d_tensors_[1](g, n, k, ho, wo));
                 }
                 else
                 {
-                    throw std::runtime_error("ElementOp not supported in reference.");
+                    throw std::runtime_error("Output ElementOp not supported in reference.");
                 }
                 arg.output_(g, n, k, ho, wo) = v_out;
             };
@@ -303,18 +302,18 @@ struct ReferenceConvFwd : public device::BaseOperator
                 else if constexpr(NumDTensor == 1)
                 {
                     arg.out_element_op_(
-                        v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, d_o, ho, wo));
+                        v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, d_o, ho, wo));
                 }
                 else if constexpr(NumDTensor == 2)
                 {
                     arg.out_element_op_(v_out,
                                         v_acc_converted,
-                                        arg.postop_tensors_[0](g, n, k, d_o, ho, wo),
-                                        arg.postop_tensors_[1](g, n, k, d_o, ho, wo));
+                                        arg.d_tensors_[0](g, n, k, d_o, ho, wo),
+                                        arg.d_tensors_[1](g, n, k, d_o, ho, wo));
                 }
                 else
                 {
-                    throw std::runtime_error("ElementOp not supported in reference.");
+                    throw std::runtime_error("Output ElementOp not supported in reference.");
                 }
                 arg.output_(g, n, k, d_o, ho, wo) = v_out;
             };
@@ -360,7 +359,7 @@ struct ReferenceConvFwd : public device::BaseOperator
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op,
-                            const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors = {})
+                            const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors = {})
     {
         return Argument{input,
                         weight,
@@ -372,7 +371,7 @@ struct ReferenceConvFwd : public device::BaseOperator
                         in_element_op,
                         wei_element_op,
                         out_element_op,
-                        postop_tensors};
+                        d_tensors};
     }
     static auto MakeInvoker() { return Invoker{}; }
...
@@ -43,7 +43,7 @@ template <index_t NDimSpatial,
           typename DsLayout,
           typename ELayout,
           ConvolutionForwardSpecialization ConvSpec>
-using device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances = std::tuple<
+using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances = std::tuple<
     // clang-format off
     //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -65,7 +65,7 @@ template <index_t NDimSpatial,
           typename DsLayout,
           typename ELayout,
           ConvolutionForwardSpecialization ConvSpec>
-using device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances = std::tuple<
+using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances = std::tuple<
     // clang-format off
     //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -87,7 +87,7 @@ template <index_t NDimSpatial,
           typename DsLayout,
           typename ELayout,
           ConvolutionForwardSpecialization ConvSpec>
-using device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances = std::tuple<
+using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances = std::tuple<
     // clang-format off
     //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -109,7 +109,7 @@ template <index_t NDimSpatial,
           typename DsLayout,
           typename ELayout,
           ConvolutionForwardSpecialization ConvSpec>
-using device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances = std::tuple<
+using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances = std::tuple<
     // clang-format off
     //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
@@ -5,6 +5,7 @@
 #include <vector>
 #include <memory>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -22,7 +23,7 @@ using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAd
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -38,7 +39,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16
 #endif
 #ifdef CK_ENABLE_FP16
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -54,7 +55,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_
 #endif
 #ifdef CK_ENABLE_FP32
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -70,7 +71,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_
 #endif
 #ifdef CK_ENABLE_INT8
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -135,7 +136,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                      is_same_v<OutDataType, float>)
         {
-            add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
+            add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                 op_ptrs);
         }
 #endif
@@ -143,7 +144,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
         {
-            add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+            add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
                 op_ptrs);
         }
 #endif
@@ -151,7 +152,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                      is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
         {
-            add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+            add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
                 op_ptrs);
         }
 #endif
@@ -159,7 +160,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                      is_same_v<OutDataType, int8_t>)
         {
-            add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
+            add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
                 op_ptrs);
         }
 #endif
...
set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU})
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16
 {
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwdDefault>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwd1x1P0>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwd1x1S1P0>{});
 }
 } // namespace instance
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_
 {
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwdDefault>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwd1x1P0>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwd1x1S1P0>{});
 }
 } // namespace instance
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_
 {
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwdDefault>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwd1x1P0>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
                                                                   NDHWGC,
                                                                   GKZYXC,
                                                                   ck::Tuple<NDHWGK, NDHWGK>,
                                                                   NDHWGK,
                                                                   ConvFwd1x1S1P0>{});
 }
 } // namespace instance
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
                                                               NDHWGC,
                                                               GKZYXC,
@@ -24,28 +24,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8
 {
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwdDefault>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwd1x1P0>{});
     add_device_operation_instances(
         instances,
-        device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
                                                                    NDHWGC,
                                                                    GKZYXC,
                                                                    ck::Tuple<NDHWGK, NDHWGK>,
                                                                    NDHWGK,
                                                                    ConvFwd1x1S1P0>{});
 }
 } // namespace instance
...
set(GROUPED_CONV3D_FWD_SCALEADDX2_RELU
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_scaleaddx2_relu_instance ${GROUPED_CONV3D_FWD_SCALEADDX2_RELU})