Commit c6891e12 authored by rocking's avatar rocking
Browse files

Merge branch 'develop' into standalone-layernorm

parents f591ad27 8e374781
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); ...@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); ...@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); ...@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); ...@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); ...@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -21,7 +21,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); ...@@ -21,7 +21,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on // clang-format on
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ...@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -22,6 +22,7 @@ set(PROFILER_SOURCE ...@@ -22,6 +22,7 @@ set(PROFILER_SOURCE
src/profile_conv_bwd_weight.cpp src/profile_conv_bwd_weight.cpp
src/profile_batched_gemm_reduce.cpp src/profile_batched_gemm_reduce.cpp
src/profile_gemm_add_add_fastgelu.cpp src/profile_gemm_add_add_fastgelu.cpp
src/profile_normalization.cpp
) )
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
...@@ -46,4 +47,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) ...@@ -46,4 +47,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int M, int M,
int N, int N,
int K, int K,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
std::size_t batch_stride,
auto layout) { auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1})); std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride})); std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); Tensor<ADataType> a_g_m_k(
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result( Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result( Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
...@@ -116,19 +122,21 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -116,19 +122,21 @@ bool profile_batched_gemm_impl(int do_verification,
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance:: BLayout,
get_device_batched_gemm_instances<ADataType, CLayout,
BDataType, ADataType,
CDataType, BDataType,
ALayout, CDataType,
BLayout, AElementOp,
CLayout>(); BElementOp,
CElementOp>;
if(op_ptrs.size() <= 0) // get device op instances
{ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
throw std::runtime_error("wrong! no device GEMM instance found"); DeviceOp>::GetInstances();
}
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name; std::string best_op_name;
float best_ave_time = 0; float best_ave_time = 0;
...@@ -148,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -148,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn ...@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&); std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_weight_instance { namespace instance {
using DeviceConvBwdWeightNoOpPtr = using DeviceConvBwdWeightNoOpPtr =
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( ...@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdWeightNoOpPtr>&); std::vector<DeviceConvBwdWeightNoOpPtr>&);
} // namespace device_conv2d_bwd_weight_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_weight_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_add_instance { namespace instance {
using DeviceConvFwdBiasReluAddPtr = using DeviceConvFwdBiasReluAddPtr =
DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr = ...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluAddPtr>&); std::vector<DeviceConvFwdBiasReluAddPtr>&);
} // namespace device_conv2d_fwd_bias_activation_add_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_instance { namespace instance {
using DeviceConvFwdBiasReluPtr = using DeviceConvFwdBiasReluPtr =
DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr = ...@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&); std::vector<DeviceConvFwdBiasReluPtr>&);
} // namespace device_conv2d_fwd_bias_activation_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
......
...@@ -22,7 +22,7 @@ using INT8 = int8_t; ...@@ -22,7 +22,7 @@ using INT8 = int8_t;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_data_instance { namespace instance {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( ...@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
namespace ck { namespace ck {
namespace profiler { namespace profiler {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
template <typename InLayout> template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims, HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
...@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
break; break;
default: break; default: break;
......
...@@ -10,13 +10,12 @@ ...@@ -10,13 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck { namespace ck {
...@@ -30,9 +29,7 @@ template <typename ADataType, ...@@ -30,9 +29,7 @@ template <typename ADataType,
typename EDataType, typename EDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename D0Layout, typename DELayout> // assume Ds and E have same layout
typename D1Layout,
typename ELayout>
bool profile_gemm_add_add_fastgelu_impl(int do_verification, bool profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method, int init_method,
bool /*do_log*/, bool /*do_log*/,
...@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
...@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{}; const auto cde_element_op = CDEElementOp{};
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance:: ALayout,
get_device_gemm_add_add_fastgelu_instances<ADataType, BLayout,
BDataType, DELayout,
AccDataType, ADataType,
D0DataType, BDataType,
D1DataType, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
ALayout, ck::tensor_operation::element_wise::PassThrough,
BLayout, ck::tensor_operation::element_wise::PassThrough,
D0Layout, ck::tensor_operation::element_wise::AddAddFastGelu>;
D1Layout,
ELayout>(); // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr< using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( ...@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&); std::vector<DeviceGemmAlphaBetaPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmAlphaBetaPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
} }
} }
...@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
} }
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f ...@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr< using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( ...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluAddPtr>&); std::vector<DeviceGemmBiasReluAddPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluAddPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluAddPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
...@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, ...@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr< using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( ...@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasReluPtr>&); std::vector<DeviceGemmBiasReluPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification, ...@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
c0_n_device_buf.ToDevice(c0_n.mData.data()); c0_n_device_buf.ToDevice(c0_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification, ...@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
} }
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
...@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification, ...@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device op instances using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance:: BLayout,
get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(); CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(op_ptrs.size() <= 0) // get device op instances
{ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
throw std::runtime_error("wrong! no device GEMM instance found"); DeviceOp>::GetInstances();
}
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM // Run reference GEMM
if(do_verification) if(do_verification)
...@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification, ...@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
ck::tensor_operation::element_wise::PassThrough{}, a_element_op,
ck::tensor_operation::element_wise::PassThrough{}, b_element_op,
ck::tensor_operation::element_wise::PassThrough{}); c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment