Unverified Commit 0dcb3496 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)

* interface for GEMM and GEMM+add+add+fastgelu

* rename namespace

* instance factory

* fix build

* fix build; add GEMM client example

* clean
parent fa9a0a5c
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn ...@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&); std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_weight_instance { namespace instance {
using DeviceConvBwdWeightNoOpPtr = using DeviceConvBwdWeightNoOpPtr =
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( ...@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdWeightNoOpPtr>&); std::vector<DeviceConvBwdWeightNoOpPtr>&);
} // namespace device_conv2d_bwd_weight_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_add_instance { namespace instance {
using DeviceConvFwdBiasReluAddPtr = using DeviceConvFwdBiasReluAddPtr =
DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr = ...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluAddPtr>&); std::vector<DeviceConvFwdBiasReluAddPtr>&);
} // namespace device_conv2d_fwd_bias_activation_add_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_instance { namespace instance {
using DeviceConvFwdBiasReluPtr = using DeviceConvFwdBiasReluPtr =
DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr = ...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&); std::vector<DeviceConvFwdBiasReluPtr>&);
} // namespace device_conv2d_fwd_bias_activation_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
......
...@@ -22,7 +22,7 @@ using INT8 = int8_t; ...@@ -22,7 +22,7 @@ using INT8 = int8_t;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_data_instance { namespace instance {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( ...@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
namespace ck { namespace ck {
namespace profiler { namespace profiler {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
template <typename InLayout> template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims, HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
...@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
break; break;
default: break; default: break;
......
...@@ -10,13 +10,12 @@ ...@@ -10,13 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck { namespace ck {
...@@ -30,9 +29,7 @@ template <typename ADataType, ...@@ -30,9 +29,7 @@ template <typename ADataType,
typename EDataType, typename EDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename D0Layout, typename DELayout> // assume Ds and E have same layout
typename D1Layout,
typename ELayout>
bool profile_gemm_add_add_fastgelu_impl(int do_verification, bool profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method, int init_method,
bool /*do_log*/, bool /*do_log*/,
...@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
...@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{}; const auto cde_element_op = CDEElementOp{};
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance:: ALayout,
get_device_gemm_add_add_fastgelu_instances<ADataType, BLayout,
BDataType, DELayout,
AccDataType, ADataType,
D0DataType, BDataType,
D1DataType, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
ALayout, ck::tensor_operation::element_wise::PassThrough,
BLayout, ck::tensor_operation::element_wise::PassThrough,
D0Layout, ck::tensor_operation::element_wise::AddAddFastGelu>;
D1Layout,
ELayout>(); // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr< using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( ...@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&); std::vector<DeviceGemmAlphaBetaPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmAlphaBetaPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
} }
} }
...@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
} }
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f ...@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr< using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( ...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluAddPtr>&); std::vector<DeviceGemmBiasReluAddPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluAddPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluAddPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr< using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( ...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluPtr>&); std::vector<DeviceGemmBiasReluPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification, ...@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
c0_n_device_buf.ToDevice(c0_n.mData.data()); c0_n_device_buf.ToDevice(c0_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification, ...@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
} }
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
...@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification, ...@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance:: BLayout,
get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(); CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(op_ptrs.size() <= 0) // get device op instances
{ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
throw std::runtime_error("wrong! no device GEMM instance found"); DeviceOp>::GetInstances();
}
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM // Run reference GEMM
if(do_verification) if(do_verification)
...@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification, ...@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
ck::tensor_operation::element_wise::PassThrough{}, a_element_op,
ck::tensor_operation::element_wise::PassThrough{}, b_element_op,
ck::tensor_operation::element_wise::PassThrough{}); c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( ...@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&); std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
...@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
const auto op_ptrs = BLayout,
ck::tensor_operation::device::device_gemm_instance::get_device_gemm_splitk_instances< CLayout,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
ALayout, AElementOp,
BLayout, BElementOp,
CLayout>(); CElementOp>;
if(op_ptrs.size() <= 0) // get device op instances
{ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
throw std::runtime_error("wrong! no device operation instance found"); DeviceOp>::GetInstances();
}
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM // Run reference GEMM
if(do_verification) if(do_verification)
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_grouped_gemm_instance { namespace instance {
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr< using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( ...@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&); std::vector<DeviceGroupedGemmNoOpPtr>&);
} // namespace device_grouped_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification,
} }
// add device GEMM instances // add device GEMM instances
std::vector< std::vector<ck::tensor_operation::device::instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_grouped_gemm_instance:: ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_grouped_gemm_instance:: ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_grouped_gemm_instance:: ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_grouped_gemm_instance:: ck::tensor_operation::device::instance::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
} }
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_normalization_instance { namespace instance {
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&); void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&); void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
...@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP ...@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&); void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&); void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
} // namespace device_normalization_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification, ...@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
is_same<AccDataType, float>::value) is_same<AccDataType, float>::value)
{ {
if(in_length.size() == 3) if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance:: tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
add_device_softmax_f16_f16_rank3_instances(instances); instances);
if(in_length.size() == 4) if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance:: tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
add_device_softmax_f16_f16_rank4_instances(instances); instances);
} }
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value && else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
is_same<AccDataType, float>::value) is_same<AccDataType, float>::value)
{ {
if(in_length.size() == 3) if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance:: tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
add_device_softmax_f32_f32_rank3_instances(instances); instances);
if(in_length.size() == 4) if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance:: tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
add_device_softmax_f32_f32_rank4_instances(instances); instances);
} }
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex> template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
struct ReduceDescription struct ReduceDescription
...@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description, ...@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description,
return (result); return (result);
}; };
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification,
float beta) float beta)
{ {
using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::device_reduce_instance; using namespace ck::tensor_operation::device::instance;
using ck::host_common::dumpBufferToFile; using ck::host_common::dumpBufferToFile;
constexpr bool op_support_indices = constexpr bool op_support_indices =
...@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification, ...@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification,
bool pass = true; bool pass = true;
using tuple_of_description_instances = using tuple_of_description_instances =
tensor_operation::device::device_reduce_instance::reduce_description_instances; tensor_operation::device::instance::reduce_description_instances;
const auto tuple_object = tuple_of_description_instances{}; const auto tuple_object = tuple_of_description_instances{};
......
...@@ -75,9 +75,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -75,9 +75,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto e_type, auto e_type,
auto a_layout, auto a_layout,
auto b_layout, auto b_layout,
auto d0_layout, auto de_layout) {
auto d1_layout,
auto e_layout) {
using ADataType = decltype(a_type); using ADataType = decltype(a_type);
using BDataType = decltype(b_type); using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type); using AccDataType = decltype(acc_type);
...@@ -87,15 +85,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -87,15 +85,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
using ALayout = decltype(a_layout); using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using BLayout = decltype(b_layout);
using D0Layout = decltype(d0_layout); using DELayout = decltype(de_layout);
using D1Layout = decltype(d1_layout);
using ELayout = decltype(e_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K; const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M; const int DefaultStrideD0 = ck::is_same_v<DELayout, Row> ? N : M;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M; const int DefaultStrideD1 = ck::is_same_v<DELayout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M; const int DefaultStrideE = ck::is_same_v<DELayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType, bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
BDataType, BDataType,
...@@ -105,9 +101,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -105,9 +101,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
EDataType, EDataType,
ALayout, ALayout,
BLayout, BLayout,
D0Layout, DELayout>(
D1Layout,
ELayout>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -126,22 +120,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -126,22 +120,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::MK_NK_MN_MN_MN) layout == MatrixLayout::MK_NK_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_KN_MN_MN_MN) layout == MatrixLayout::KM_KN_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_NK_MN_MN_MN) layout == MatrixLayout::KM_NK_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
} }
else else
{ {
......
# Launch an interactive ROCm TensorFlow 1.15 (rocm4.1) dev container with the
# given host directory bind-mounted as the container's working directory.
#
# Usage: <script> <workspace-dir>
#   $1 - host path mounted at /root/workspace inside the container (required)
WORKSPACE=$1

# Fail fast on a missing argument: an empty $WORKSPACE would otherwise produce
# a broken "-v :/root/workspace" mount.
if [ -z "$WORKSPACE" ]; then
    echo "usage: $0 <workspace-dir>" >&2
    exit 1
fi

echo "workspace: " "$WORKSPACE"

# Quotes around $WORKSPACE keep paths with spaces intact (ShellCheck SC2086).
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v "$WORKSPACE":/root/workspace \
rocm/tensorflow:rocm4.1-tf1.15-dev \
/bin/bash
#--network host \
# Launch an interactive ROCm TensorFlow 2.6 (rocm4.3.1) dev container with the
# given host directory bind-mounted as the container's working directory.
#
# Usage: <script> <workspace-dir>
#   $1 - host path mounted at /root/workspace inside the container (required)
WORKSPACE=$1

# Fail fast on a missing argument: an empty $WORKSPACE would otherwise produce
# a broken "-v :/root/workspace" mount.
if [ -z "$WORKSPACE" ]; then
    echo "usage: $0 <workspace-dir>" >&2
    exit 1
fi

echo "workspace: " "$WORKSPACE"

# Quotes around $WORKSPACE keep paths with spaces intact (ShellCheck SC2086).
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v "$WORKSPACE":/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
#--network host \
...@@ -20,7 +20,7 @@ using INT8 = int8_t; ...@@ -20,7 +20,7 @@ using INT8 = int8_t;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_data_instance { namespace instance {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -220,28 +220,28 @@ int main(int argc, char* argv[]) ...@@ -220,28 +220,28 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment