Unverified Commit 0dcb3496 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)

* interface for GEMM and GEMM+add+add+fastgelu

* rename namespace

* instance factory

* fix build

* fix build; add GEMM client example

* clean
parent fa9a0a5c
......@@ -19,7 +19,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using F32 = float;
using F16 = ck::half_t;
......@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
b_device_buf.ToDevice(b_g_k_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs);
}
......@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs);
}
......@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs);
}
......@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs);
}
......
......@@ -18,7 +18,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_weight_instance {
namespace instance {
using DeviceConvBwdWeightNoOpPtr =
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
......@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdWeightNoOpPtr>&);
} // namespace device_conv2d_bwd_weight_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
......
......@@ -17,7 +17,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_instance {
namespace instance {
using DeviceConvFwdBiasReluAddPtr =
DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
......@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluAddPtr>&);
} // namespace device_conv2d_fwd_bias_activation_add_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
......
......@@ -17,7 +17,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_instance {
namespace instance {
using DeviceConvFwdBiasReluPtr =
DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
......@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&);
} // namespace device_conv2d_fwd_bias_activation_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
......
......@@ -22,7 +22,7 @@ using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {
namespace instance {
using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
......@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
using DeviceConvBwdDataNoOpPtr =
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
......@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
break;
default: break;
......@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break;
default: break;
......@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break;
default: break;
......@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial)
{
case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
break;
case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
break;
case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
break;
default: break;
......
......@@ -10,13 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
......@@ -30,9 +29,7 @@ template <typename ADataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout>
typename DELayout> // assume Ds and E have same layout
bool profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method,
bool /*do_log*/,
......@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
......@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
// add device op instances
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_add_add_fastgelu_instances<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>();
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
......@@ -17,7 +17,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
ck::tensor_operation::element_wise::PassThrough,
......@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmAlphaBetaPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
......@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
......
......@@ -19,7 +19,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using F32 = float;
using F16 = ck::half_t;
......@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
}
......@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
}
......@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
}
......@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
}
......
......@@ -18,7 +18,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr<
ck::tensor_operation::element_wise::PassThrough,
......@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluAddPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluAddPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluAddPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(
gemm_ptrs);
}
......@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(
gemm_ptrs);
}
......@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
gemm_ptrs);
}
......@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
gemm_ptrs);
}
......
......@@ -18,7 +18,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr<
ck::tensor_operation::element_wise::PassThrough,
......@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
c0_n_device_buf.ToDevice(c0_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
......
......@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
......@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device op instances
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>();
using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
......@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
......@@ -19,7 +19,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
namespace instance {
using F32 = float;
using F16 = ck::half_t;
......@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
}
......@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
}
......@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
}
......@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
}
......
......@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
......@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device op instances
const auto op_ptrs =
ck::tensor_operation::device::device_gemm_instance::get_device_gemm_splitk_instances<
ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device operation instance found");
}
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
......
......@@ -20,7 +20,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {
namespace instance {
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
ck::tensor_operation::element_wise::PassThrough,
......@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
} // namespace device_grouped_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification,
}
// add device GEMM instances
std::vector<
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
gemm_ptrs;
std::vector<ck::tensor_operation::device::instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
......@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_grouped_gemm_instance::
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_grouped_gemm_instance::
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_grouped_gemm_instance::
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_grouped_gemm_instance::
ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
......
......@@ -18,7 +18,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_normalization_instance {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
......@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
} // namespace device_normalization_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
is_same<AccDataType, float>::value)
{
if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f16_f16_rank3_instances(instances);
tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
instances);
if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f16_f16_rank4_instances(instances);
tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
instances);
}
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
is_same<AccDataType, float>::value)
{
if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f32_f32_rank3_instances(instances);
tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
instances);
if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f32_f32_rank4_instances(instances);
tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
instances);
}
}
......
......@@ -16,7 +16,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
namespace instance {
template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
struct ReduceDescription
......@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description,
return (result);
};
} // namespace device_reduce_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification,
float beta)
{
using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::device_reduce_instance;
using namespace ck::tensor_operation::device::instance;
using ck::host_common::dumpBufferToFile;
constexpr bool op_support_indices =
......@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification,
bool pass = true;
using tuple_of_description_instances =
tensor_operation::device::device_reduce_instance::reduce_description_instances;
tensor_operation::device::instance::reduce_description_instances;
const auto tuple_object = tuple_of_description_instances{};
......
......@@ -75,9 +75,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto e_type,
auto a_layout,
auto b_layout,
auto d0_layout,
auto d1_layout,
auto e_layout) {
auto de_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
......@@ -87,15 +85,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using D0Layout = decltype(d0_layout);
using D1Layout = decltype(d1_layout);
using ELayout = decltype(e_layout);
using DELayout = decltype(de_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
const int DefaultStrideD0 = ck::is_same_v<DELayout, Row> ? N : M;
const int DefaultStrideD1 = ck::is_same_v<DELayout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<DELayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
BDataType,
......@@ -105,9 +101,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>(
DELayout>(
do_verification,
init_method,
do_log,
......@@ -126,22 +120,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::MK_NK_MN_MN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_KN_MN_MN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_NK_MN_MN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
}
else
{
......
WORKSPACE=$1
echo "workspace: " $WORKSPACE
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v $WORKSPACE:/root/workspace \
rocm/tensorflow:rocm4.1-tf1.15-dev \
/bin/bash
#--network host \
WORKSPACE=$1
echo "workspace: " $WORKSPACE
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v $WORKSPACE:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
#--network host \
......@@ -20,7 +20,7 @@ using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {
namespace instance {
using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
......@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -220,28 +220,28 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment