Commit c3738ce3 authored by Bartlomiej Kocot

Fixes

parent eb898ad6
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_operations)
......@@ -9,7 +9,7 @@
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleaddx2_relu.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -50,11 +50,11 @@ struct SimpleDeviceMem
void* p_mem_;
};
int main()
int execute_conv_fwd_scaleadd_scaleadd_relu()
{
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space
// However, CK's API only accept length and stride with order of GNCDHW/GKCZYX/GNKDHW
// Hence, we need to adjust the order of stride
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space.
// However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW.
// Hence, we need to adjust the order of strides.
std::array<ck::index_t, 6> in_lengths{G, N, C, Di, Hi, Wi};
std::array<ck::index_t, 6> in_strides{
C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
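Illustrative only (not part of this commit): the same reordering applies to the output tensor. Assuming a packed NDHWGK layout and that K, Do, Ho and Wo are defined analogously to the input dimensions above, the GNKDHW-ordered lengths and strides would be:
// Hypothetical sketch, mirroring in_lengths/in_strides above:
std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
std::array<ck::index_t, 6> out_strides{
    K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};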
......@@ -208,4 +208,5 @@ int main()
std::cout << "Done" << std::endl;
}
return 0;
}
......@@ -8,8 +8,10 @@
using InDataType = ck::bhalf_t;
using WeiDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t;
// Use std tuple instead ck tuple to avoid clang
// Use std::tuple instead of ck::Tuple to avoid clang's
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<ck::bhalf_t, ck::bhalf_t>;
#include "grouped_conv_fwd_scaleaddx2_relu.inc"
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
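A minimal sketch for reference (hypothetical; it assumes the shared .inc file reads the element types back out with standard tuple traits, which is not shown in this diff, and requires <tuple>):
// Hypothetical usage, not taken from the .inc file:
using D0DataType = std::tuple_element_t<0, DDataTypes>; // ck::bhalf_t in this file
using D1DataType = std::tuple_element_t<1, DDataTypes>; // ck::bhalf_t in this file
static_assert(std::tuple_size_v<DDataTypes> == 2, "ScaleAdd+ScaleAdd expects two D tensors");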
......@@ -8,8 +8,10 @@
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
// Use std tuple instead ck tuple to avoid clang
// Use std::tuple instead of ck::Tuple to avoid clang's
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<ck::half_t, ck::half_t>;
#include "grouped_conv_fwd_scaleaddx2_relu.inc"
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
......@@ -8,8 +8,10 @@
using InDataType = float;
using WeiDataType = float;
using OutDataType = float;
// Use std tuple instead ck tuple to avoid clang
// Use std::tuple instead of ck::Tuple to avoid clang's
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<float, float>;
#include "grouped_conv_fwd_scaleaddx2_relu.inc"
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
......@@ -8,8 +8,10 @@
using InDataType = int8_t;
using WeiDataType = int8_t;
using OutDataType = int8_t;
// Use std tuple instead ck tuple to avoid clang
// Use std::tuple instead of ck::Tuple to avoid clang's
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<float, float>;
#include "grouped_conv_fwd_scaleaddx2_relu.inc"
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); }
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_fp32 grouped_conv_fwd_scaleaddx2_relu_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_fp32 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_fp16 grouped_conv_fwd_scaleaddx2_relu_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_bf16 grouped_conv_fwd_scaleaddx2_relu_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_bf16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_convnd_fwd_scaleaddx2_relu_int8 grouped_conv_fwd_scaleaddx2_relu_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleaddx2_relu_int8 PRIVATE composable_kernel::device_operations)
......@@ -31,8 +31,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleaddx2_relu_fp16 convnd_fwd_xdl_scaleaddx2_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleaddx2_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
set(target 1)
endif()
endforeach()
......@@ -190,9 +190,8 @@ bool run_grouped_conv_fwd(bool do_verification,
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
throw std::runtime_error("The device op with the specified compilation parameters does "
"not support this convolution problem.");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
......@@ -97,7 +97,7 @@ using DeviceGroupedConvNDFwdInstance =
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
namespace {
// Use own implementation to pass two more tensors for post op
// Use a custom implementation to pass two additional tensors for the post-op
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
......@@ -181,7 +181,7 @@ bool run_grouped_conv_fwd(bool do_verification,
copy(conv_param.input_right_pads_, input_right_pads);
const std::array<const void*, NumDs> ds = {d0_buf.GetDeviceBuffer(), d1_buf.GetDeviceBuffer()};
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
......@@ -208,9 +208,8 @@ bool run_grouped_conv_fwd(bool do_verification,
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
throw std::runtime_error("The device op with the specified compilation parameters does "
"not support this convolution problem.");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
......@@ -59,11 +59,11 @@ struct ReferenceConvFwd : public device::BaseOperator
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors)
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors)
: input_{input},
weight_{weight},
output_{output},
postop_tensors_{postop_tensors},
d_tensors_{d_tensors},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
......@@ -78,7 +78,7 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<WeiDataType>& weight_;
Tensor<OutDataType>& output_;
const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors_;
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
......@@ -141,19 +141,18 @@ struct ReferenceConvFwd : public device::BaseOperator
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, wo));
arg.out_element_op_(v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.postop_tensors_[0](g, n, k, wo),
arg.postop_tensors_[1](g, n, k, wo));
arg.d_tensors_[0](g, n, k, wo),
arg.d_tensors_[1](g, n, k, wo));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, wo) = v_out;
};
......@@ -216,18 +215,18 @@ struct ReferenceConvFwd : public device::BaseOperator
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, ho, wo));
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.postop_tensors_[0](g, n, k, ho, wo),
arg.postop_tensors_[1](g, n, k, ho, wo));
arg.d_tensors_[0](g, n, k, ho, wo),
arg.d_tensors_[1](g, n, k, ho, wo));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, ho, wo) = v_out;
};
......@@ -303,18 +302,18 @@ struct ReferenceConvFwd : public device::BaseOperator
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.postop_tensors_[0](g, n, k, d_o, ho, wo));
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, d_o, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.postop_tensors_[0](g, n, k, d_o, ho, wo),
arg.postop_tensors_[1](g, n, k, d_o, ho, wo));
arg.d_tensors_[0](g, n, k, d_o, ho, wo),
arg.d_tensors_[1](g, n, k, d_o, ho, wo));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
throw std::runtime_error("Output ElementOp not supported in reference.");
}
arg.output_(g, n, k, d_o, ho, wo) = v_out;
};
......@@ -360,7 +359,7 @@ struct ReferenceConvFwd : public device::BaseOperator
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<OutDataType>, NumDTensor>& postop_tensors = {})
const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors = {})
{
return Argument{input,
weight,
......@@ -372,7 +371,7 @@ struct ReferenceConvFwd : public device::BaseOperator
in_element_op,
wei_element_op,
out_element_op,
postop_tensors};
d_tensors};
}
static auto MakeInvoker() { return Invoker{}; }
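For context, a hedged sketch of how this reference operator is typically driven from host code (the variable names and the elided parameters are illustrative; only the trailing d_tensors argument is new in this commit):
// Hypothetical usage sketch:
auto ref_conv     = ReferenceConvFwd</* template args as in the caller */>{};
auto ref_invoker  = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in, wei, out,
                                          conv_strides, conv_dilations,
                                          in_left_pads, in_right_pads,
                                          in_element_op, wei_element_op, out_element_op,
                                          {d0, d1}); // the two post-op (D) tensors
ref_invoker.Run(ref_argument);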
......
......@@ -43,7 +43,7 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances = std::tuple<
using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......@@ -65,7 +65,7 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances = std::tuple<
using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......@@ -87,7 +87,7 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances = std::tuple<
using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......@@ -109,7 +109,7 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances = std::tuple<
using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......
......@@ -5,6 +5,7 @@
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
......@@ -22,7 +23,7 @@ using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAd
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -38,7 +39,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -54,7 +55,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -70,7 +71,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -135,7 +136,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
#endif
......@@ -143,7 +144,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
......@@ -151,7 +152,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
......@@ -159,7 +160,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
op_ptrs);
}
#endif
......
set(GROUPED_CONV3D_FWD_SCALEADD_SCALEADD_RELU
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_SCALEADD_SCALEADD_RELU})
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleaddx2_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
......@@ -24,28 +24,28 @@ void add_device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_scaleaddx2_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
......
set(GROUPED_CONV3D_FWD_SCALEADDX2_RELU
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleaddx2_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_scaleaddx2_relu_instance ${GROUPED_CONV3D_FWD_SCALEADDX2_RELU})