Commit 07a673c6 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
...
@@ -20,8 +20,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
-static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
...
@@ -15,6 +15,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
+${PROJECT_SOURCE_DIR}/library/include/ck/library/utility
${PROJECT_SOURCE_DIR}/profiler/include
${PROJECT_SOURCE_DIR}/external/include/half
)
@@ -36,6 +37,7 @@ set(PROFILER_SOURCE
src/profile_convnd_bwd_data.cpp
src/profile_reduce.cpp
src/profile_grouped_gemm.cpp
+src/profile_conv_bwd_weight.cpp
src/profile_batched_gemm_reduce.cpp
)
@@ -57,4 +59,5 @@ target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
## Docker script
```bash
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
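`${PATH_TO_LOCAL_WORKSPACE}` is a placeholder for a local checkout of the repository; for example (hypothetical path), run `export PATH_TO_LOCAL_WORKSPACE=$HOME/composable_kernel` before the command above.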
## Build `ckProfiler`
```bash
mkdir build && cd build
```
```bash
# Need to specify the target ID; the example below is for gfx908
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```
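For a different GPU, the two gfx908-specific flags presumably change together, e.g. `-DCK_AMD_GPU_GFX90A --amdgpu-target=gfx90a` for gfx90a (an assumption; check that the matching `CK_AMD_GPU_*` macro exists in your checkout).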
```bash
make -j ckProfiler
```
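The profiling examples below invoke the resulting binary as `./bin/ckProfiler`.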
## Profile GEMM kernels
```bash
#arg1: tensor operation (gemm=GEMM)
@@ -42,8 +9,8 @@ cmake \
#arg7: run kernel # of times (>1)
#arg8 to 13: M, N, K, StrideA, StrideB, StrideC
-##################### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
-./profiler/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096
+################ op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
+./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096
```
Result (MI100 @ 1087MHz, 133.5 TFlops peak FP16)
@@ -55,7 +22,7 @@ c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
```
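As a sanity check on that figure: the GEMM above does 2·M·N·K = 2·3840·4096·4096 ≈ 1.289e11 flop, and 1.289e11 flop / 1.1933 ms ≈ 107.98 TFlops, about 81% of the quoted 133.5 TFlops FP16 peak.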
-## Profile forward convolution kernels
+## Profile 2d forward convolution kernels
```bash
#arg1: tensor operation (conv=Convolution)
#arg2: data type (0=fp32, 1=fp16)
@@ -67,8 +34,8 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
#arg8: print matrix value (0=no, 1=yes)
#arg9: run kernel # of times (>1)
#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
-##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
-./profiler/ckProfiler conv_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
+################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
+./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
```
Result (MI100 @ 1087MHz, 133.5 TFlops peak FP16)
...
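For these arguments the output spatial size is Ho = Wo = (71 + 1 + 1 - 3)/2 + 1 = 36, so assuming the usual 2·N·K·Ho·Wo·C·Y·X flop count (the same formula the backward-weight profiler below uses), one invocation does 2·128·256·36·36·192·3·3 ≈ 1.468e11 flop.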
@@ -2,6 +2,7 @@
#include <memory>
+#include "check_err.hpp"
#include "config.hpp"
#include "element_wise_operation.hpp"
#include "tensor_layout.hpp"
@@ -69,7 +70,7 @@ bool profile_batched_gemm_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
-int BatchCount = 1)
+int BatchCount)
{
bool pass = true;
@@ -393,7 +394,6 @@ bool profile_batched_gemm_impl(int do_verification,
}
else
{
float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result);
pass = pass && (err < 1E-6);
}
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -253,7 +255,8 @@ void profile_conv_bwd_data_impl(int do_verification,
{
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
-check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
+ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                     in_n_c_hi_wi_host_result.mData);
if(do_log)
{
...
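Across these profiler headers the old `check_error(host, device)` helper is replaced by `ck::utils::check_err`, which takes the device result first and the host reference second and compares the raw `mData` vectors. As a rough mental model only, here is a minimal sketch of such a helper (the real one in `check_err.hpp` may differ in signature, tolerances, and reporting):

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for ck::utils::check_err: element-wise comparison of a
// device result against a host reference with a mixed absolute/relative bound.
template <typename T>
bool check_err(const std::vector<T>& result,    // device output
               const std::vector<T>& reference, // host reference
               double rtol = 1e-5,
               double atol = 1e-8)
{
    if(result.size() != reference.size())
    {
        std::cerr << "check_err: size mismatch\n";
        return false;
    }
    for(std::size_t i = 0; i < result.size(); ++i)
    {
        const double r = static_cast<double>(result[i]);
        const double e = static_cast<double>(reference[i]);
        if(std::abs(r - e) > atol + rtol * std::abs(e))
        {
            std::cerr << "check_err: mismatch at index " << i << ": " << r
                      << " vs " << e << '\n';
            return false;
        }
    }
    return true;
}
```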
#pragma once
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_conv_backward_weight.hpp"
#include "element_wise_operation.hpp"
#include "reference_conv_backward_weight.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_weight_instance {
using DeviceConvBwdWeightNoOpPtr =
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvBwdWeightNoOpPtr>&);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvBwdWeightNoOpPtr>&);
} // namespace device_conv2d_bwd_weight_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
template <int NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool profile_conv_bwd_weight_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t split_k)
{
const ck::index_t Y = filter_spatial_lengths[0];
const ck::index_t X = filter_spatial_lengths[1];
const ck::index_t Hi = input_spatial_lengths[0];
const ck::index_t Wi = input_spatial_lengths[1];
const ck::index_t Ho = output_spatial_lengths[0];
const ck::index_t Wo = output_spatial_lengths[1];
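// Note: the host tensors below keep logical NCHW/KCYX dimension order for every
// layout; the physical layout is expressed purely through strides (the NHWC
// branch of the descriptor lambda uses strides {C*H*W, 1, W*C, C}).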
auto f_host_tensor_descriptor =
[](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
if constexpr(is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
}
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
}
};
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
Tensor<WeiDataType> wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<WeiDataType> wei_k_c_y_x_device_result(
f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<OutDataType> out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
}
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
if(do_verification)
{
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
auto ref_conv = ReferenceConvBwdWeightInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x_host_result,
out_n_k_ho_wo,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
ref_invoker.Run(ref_argument);
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceConvBwdWeightNoOpPtr =
ck::tensor_operation::device::DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances
std::vector<DeviceConvBwdWeightNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
std::string best_conv_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device Conv instances
bool pass = true;
for(auto& conv_ptr : conv_ptrs)
{
// split-k kernels accumulate with atomic adds, so the weight (output) buffer must be zeroed first
if(split_k > 1)
{
wei_device_buf.SetZero();
}
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k);
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string conv_name = conv_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
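// ave_time is in milliseconds (see the "Perf" print below), so
// flop / 1e9 / ave_time gives TFlop/s and num_btype / 1e6 / ave_time gives GB/s.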
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
sizeof(WeiDataType) * (K * C * Y * X) +
sizeof(OutDataType) * (N * K * Ho * Wo);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << conv_name << std::endl;
if(tflops > best_tflops)
{
best_conv_name = conv_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result);
if(max_error > 8)
{
pass = false;
std::cout << "Fail info:" << conv_ptr->GetTypeString() << std::endl;
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "wei_device: ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
}
}
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
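The `split_k` handling above is why the weight buffer is zeroed before each run: when the K (reduction) dimension is split, each slice adds its partial sum into the same output buffer. A toy illustration of that accumulation pattern (plain host C++, not the kernel's actual implementation; assumes the element count divides evenly by `split_k`):

```cpp
#include <numeric>
#include <vector>

// Each of split_k slices reduces its own chunk and accumulates into `out`,
// mirroring how split-K workgroups atomic-add partial sums into a buffer
// that must start at zero (cf. wei_device_buf.SetZero() above).
float split_k_reduce(const std::vector<float>& k_products, int split_k)
{
    float out = 0.0f; // the zeroed "output buffer"
    const std::size_t chunk = k_products.size() / split_k;
    for(int s = 0; s < split_k; ++s)
    {
        out += std::accumulate(k_products.begin() + s * chunk,
                               k_products.begin() + (s + 1) * chunk,
                               0.0f);
    }
    return out;
}
```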
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -245,7 +247,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
-check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
+ck::utils::check_err(out_n_k_ho_wo_device_result.mData,
+                     out_n_k_ho_wo_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -301,7 +302,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
-check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
+ck::utils::check_err(out_n_k_ho_wo_device_result.mData,
+                     out_n_k_ho_wo_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -233,7 +234,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
-check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
+ck::utils::check_err(out_n_k_ho_wo_device_result.mData,
+                     out_n_k_ho_wo_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -253,7 +255,8 @@ void profile_conv_fwd_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
-check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
+ck::utils::check_err(out_n_k_ho_wo_device_result.mData,
+                     out_n_k_ho_wo_host_result.mData);
if(do_log)
{
...
#pragma once
#include "config.hpp"
#include "device.hpp"
-#include "conv_utils.hpp"
+#include "conv_fwd_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
@@ -12,7 +12,7 @@
using F16 = ck::half_t;
using F32 = float;
-using BF16 = ushort;
+using BF16 = ck::bhalf_t;
using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
@@ -68,13 +68,13 @@ HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::siz
switch(num_dim_spatial)
{
case 3: {
-return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 2: {
-return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
case 1: {
-return ck::conv_util::GetHostTensorDescriptor(dims, InLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
@@ -90,13 +90,13 @@ HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::s
switch(num_dim_spatial)
{
case 3: {
-return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 2: {
-return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
case 1: {
-return ck::conv_util::GetHostTensorDescriptor(dims, WeiLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
@@ -112,15 +112,14 @@ HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::siz
switch(num_dim_spatial)
{
case 3: {
-return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 2: {
-return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
case 1: {
-return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
+return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
@@ -274,13 +273,13 @@ bool profile_convnd_bwd_data_impl(int do_verification,
ck::index_t N,
ck::index_t K,
ck::index_t C,
-std::vector<ck::index_t> input_spatial_lengths,
-std::vector<ck::index_t> filter_spatial_lengths,
-std::vector<ck::index_t> output_spatial_lengths,
-std::vector<ck::index_t> conv_filter_strides,
-std::vector<ck::index_t> conv_filter_dilations,
-std::vector<ck::index_t> input_left_pads,
-std::vector<ck::index_t> input_right_pads)
+const std::vector<ck::index_t>& input_spatial_lengths,
+const std::vector<ck::index_t>& filter_spatial_lengths,
+const std::vector<ck::index_t>& output_spatial_lengths,
+const std::vector<ck::index_t>& conv_filter_strides,
+const std::vector<ck::index_t>& conv_filter_dilations,
+const std::vector<ck::index_t>& input_left_pads,
+const std::vector<ck::index_t>& input_right_pads)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -304,51 +303,49 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
-Tensor<InDataType> in_n_c_hi_wi_host_result(
+Tensor<InDataType> input_host_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
-Tensor<InDataType> in_n_c_hi_wi_device_result(
+Tensor<InDataType> input_device_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
-Tensor<WeiDataType> wei_k_c_y_x(
+Tensor<WeiDataType> weights(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
-Tensor<OutDataType> out_n_k_ho_wo(
+Tensor<OutDataType> output(
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
-std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
-std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
-std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
+std::cout << "input: " << input_host_result.mDesc << std::endl;
+std::cout << "weights: " << weights.mDesc << std::endl;
+std::cout << "output: " << output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
-out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
-wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
-out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
-wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
+weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
-DeviceMem in_device_buf(sizeof(InDataType) *
-                        in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
-DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
-DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
+DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
+DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
+DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
-out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
-wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+out_device_buf.ToDevice(output.mData.data());
+wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero
-in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
-in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
+in_device_buf.SetZero();
if(do_verification)
{
auto RunReference = [&](auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
-auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
-                                          wei_k_c_y_x,
-                                          out_n_k_ho_wo,
+auto ref_argument = ref_conv.MakeArgument(input_host_result,
+                                          weights,
+                                          output,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
@@ -358,48 +355,16 @@ bool profile_convnd_bwd_data_impl(int do_verification,
OutElementOp{});
ref_invoker.Run(ref_argument);
};
-switch(NDimSpatial)
-{
-case 3: {
-auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
-                                                                 WeiDataType,
-                                                                 OutDataType,
-                                                                 AccDataType,
-                                                                 InElementOp,
-                                                                 WeiElementOp,
-                                                                 OutElementOp,
-                                                                 3>();
-RunReference(ref_conv);
-break;
-}
-case 2: {
-auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
-                                                                 WeiDataType,
-                                                                 OutDataType,
-                                                                 AccDataType,
-                                                                 InElementOp,
-                                                                 WeiElementOp,
-                                                                 OutElementOp,
-                                                                 2>();
-RunReference(ref_conv);
-break;
-}
-case 1: {
-auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
-                                                                 WeiDataType,
-                                                                 OutDataType,
-                                                                 AccDataType,
-                                                                 InElementOp,
-                                                                 WeiElementOp,
-                                                                 OutElementOp,
-                                                                 1>();
-RunReference(ref_conv);
-break;
-}
-default: {
-throw std::runtime_error("Unsupported number of spatial dimensions provided!");
-}
-}
+auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 AccDataType,
+                                                                 InElementOp,
+                                                                 WeiElementOp,
+                                                                 OutElementOp,
+                                                                 NDimSpatial>();
+RunReference(ref_conv);
}
// add device Conv instances
@@ -448,9 +413,10 @@ bool profile_convnd_bwd_data_impl(int do_verification,
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
std::size_t flop =
-    ck::conv_util::GetFlops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
+    ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
-std::size_t num_btype = ck::conv_util::GetBtype<InDataType, WeiDataType, OutDataType>(
-    N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
+std::size_t num_btype =
+    ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
+        N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
@@ -468,9 +434,9 @@ bool profile_convnd_bwd_data_impl(int do_verification,
if(do_verification)
{
-in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
+in_device_buf.FromDevice(input_device_result.mData.data());
-if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result))
+if(!check_out(input_host_result, input_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
@@ -481,24 +447,24 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
}
-check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
+check_error(input_host_result, input_device_result);
if(do_log)
{
std::cout << "in : ";
-show_data_nhwc_layout(out_n_k_ho_wo);
+show_data_nhwc_layout(output);
std::cout << std::endl;
std::cout << "wei: ";
-show_data_nhwc_layout(wei_k_c_y_x);
+show_data_nhwc_layout(weights);
std::cout << std::endl;
std::cout << "out_host : ";
-show_data_nhwc_layout(in_n_c_hi_wi_host_result);
+show_data_nhwc_layout(input_host_result);
std::cout << std::endl;
std::cout << "out_device: ";
-show_data_nhwc_layout(in_n_c_hi_wi_device_result);
+show_data_nhwc_layout(input_device_result);
std::cout << std::endl;
}
}
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -283,7 +285,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-check_error(c_m_n_host_result, c_m_n_device_result);
+ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -257,7 +259,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-check_error(c_m_n_host_result, c_m_n_device_result);
+ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -236,7 +238,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-check_error(c_m_n_host_result, c_m_n_device_result);
+ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
...
#pragma once
#include <iomanip>
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -470,7 +472,7 @@ void profile_gemm_impl(int do_verification,
ref_invoker.Run(ref_argument);
-check_error(c_m_n_host_result, c_m_n_device_f32_result);
+ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);
if(do_log)
{
@@ -499,7 +501,7 @@ void profile_gemm_impl(int do_verification,
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
-check_error(c_m_n_host_result, c_m_n_device_result);
+ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
...
#pragma once
#include <iomanip>
+#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -283,7 +285,7 @@ void profile_grouped_gemm_impl(int do_verification,
c_element_op);
ref_invoker.Run(ref_argument);
-check_error(c_m_n_host_result, c_m_n_device_results[i]);
+ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData);
if(do_log)
{
...
#pragma once
+#include "check_err.hpp"
#include "device_reduce.hpp"
#include "device_reduce_instance.hpp"
#include "reduction_enums.hpp"
@@ -64,9 +66,9 @@ template <typename DescriptionType>
bool description_match(const DescriptionType& description,
int Rank,
const std::vector<int>& reduceDims,
-ReduceTensorOp_t ReduceOpId,
-NanPropagation_t NanOpt,
-ReduceTensorIndices_t IndicesOpt)
+ReduceTensorOp ReduceOpId,
+NanPropagation NanOpt,
+ReduceTensorIndices IndicesOpt)
{
if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
description.NanOpt_ != static_cast<int>(NanOpt) ||
@@ -148,9 +150,9 @@ template <typename InDataType,
typename OutDataType,
int Rank,
int NumReduceDim,
-ReduceTensorOp_t ReduceOpId,
-NanPropagation_t NanOpt,
-ReduceTensorIndices_t IndicesOpt>
+ReduceTensorOp ReduceOpId,
+NanPropagation NanOpt,
+ReduceTensorIndices IndicesOpt>
void profile_reduce_impl_impl(bool do_verification,
int init_method,
bool do_log,
@@ -166,17 +168,17 @@ void profile_reduce_impl_impl(bool do_verification,
using namespace ck::host_reduce;
constexpr bool op_support_indices =
-    (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
-     ReduceOpId == ReduceTensorOp_t::AMAX);
+    (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+     ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices =
-    (op_support_indices && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES));
-constexpr bool PropagateNan = (NanOpt == NanPropagation_t::PROPAGATE_NAN);
+    (op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES));
+constexpr bool PropagateNan = (NanOpt == NanPropagation::PROPAGATE_NAN);
constexpr bool out_support_atomic_add = std::is_same<OutDataType, float>::value;
constexpr bool op_support_atomic_add =
-    !op_support_indices && ReduceOpId != ReduceTensorOp_t::NORM2;
+    !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2;
constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add);
// 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
@@ -194,7 +196,7 @@ void profile_reduce_impl_impl(bool do_verification,
// 1) The indices can only be used when the reduction operation is indexable
constexpr bool invalid_reduce_3 =
-    (!op_support_indices && IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
+    (!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES);
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
@@ -207,8 +209,8 @@ void profile_reduce_impl_impl(bool do_verification,
// 1) If InDataType is int8_t, the supported operation must be either indexable operations or
// ADD/AVG
constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
-    (!op_support_indices && ReduceOpId != ReduceTensorOp_t::ADD &&
-     ReduceOpId != ReduceTensorOp_t::AVG);
+    (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
+     ReduceOpId != ReduceTensorOp::AVG);
// 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
constexpr bool invalid_reduce_6 =
@@ -455,12 +457,13 @@ void profile_reduce_impl_impl(bool do_verification,
if(do_verification)
{
out_dev.FromDevice(out.mData.data());
-check_error(out_ref, out);
+ck::utils::check_err(out.mData, out_ref.mData);
if(NeedIndices)
{
out_indices_dev.FromDevice(out_indices.mData.data());
-check_indices(out_indices_ref, out_indices);
+ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
+;
};
if(do_log)
@@ -577,12 +580,13 @@ void profile_reduce_impl_impl(bool do_verification,
if(do_verification)
{
out_dev.FromDevice(out.mData.data());
-check_error(out_ref, out);
+ck::utils::check_err(out.mData, out_ref.mData);
if(NeedIndices)
{
out_indices_dev.FromDevice(out_indices.mData.data());
-check_indices(out_indices_ref, out_indices);
+ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
+;
};
if(do_log)
@@ -631,9 +635,9 @@ void profile_reduce_impl(bool do_verification,
int nrepeat,
const std::vector<size_t>& inLengths,
const std::vector<int>& reduceDims,
-ReduceTensorOp_t ReduceOpId,
-NanPropagation_t NanOpt,
-ReduceTensorIndices_t IndicesOpt,
+ReduceTensorOp ReduceOpId,
+NanPropagation NanOpt,
+ReduceTensorIndices IndicesOpt,
float alpha,
float beta)
{
@@ -659,9 +663,9 @@ void profile_reduce_impl(bool do_verification,
OutDataType,
descType::Rank_,
descType::NumReduceDim_,
-static_cast<ReduceTensorOp_t>(descType::ReduceOpId_),
-static_cast<NanPropagation_t>(descType::NanOpt_),
-static_cast<ReduceTensorIndices_t>(descType::IndicesOpt_)>(
+static_cast<ReduceTensorOp>(descType::ReduceOpId_),
+static_cast<NanPropagation>(descType::NanOpt_),
+static_cast<ReduceTensorIndices>(descType::IndicesOpt_)>(
do_verification,
init_method,
do_log,
...
@@ -128,7 +128,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -147,7 +148,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -206,7 +208,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -225,7 +228,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -284,7 +288,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -303,7 +308,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -362,7 +368,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
-(StrideC < 0) ? N : StrideC);
+(StrideC < 0) ? N : StrideC,
+BatchCount);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
...
@@ -9,7 +9,7 @@
int profile_batched_gemm_reduce(int argc, char* argv[])
{
-enum struct GemmMatrixLayout_t
+enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
@@ -17,7 +17,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
KM_NK_MN, // 3
};
-enum struct GemmReduceDataType_t
+enum struct GemmReduceDataType
{
F32_F32_F32_F32_F32, // 0
F16_F16_F16_F32_F32, // 1
@@ -40,8 +40,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
exit(1);
}
-const auto data_type = static_cast<GemmReduceDataType_t>(std::stoi(argv[2]));
-const auto layout = static_cast<GemmMatrixLayout_t>(std::stoi(argv[3]));
+const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
+const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
@@ -57,8 +57,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
const int BatchCount = std::stoi(argv[14]);
-if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-   layout == GemmMatrixLayout_t::MK_KN_MN)
+if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
@@ -79,8 +78,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
BatchCount);
}
-else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-        layout == GemmMatrixLayout_t::MK_NK_MN)
+else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+        layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
@@ -101,8 +100,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
BatchCount);
}
-else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-        layout == GemmMatrixLayout_t::KM_KN_MN)
+else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+        layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
@@ -123,8 +122,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
BatchCount);
}
-else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-        layout == GemmMatrixLayout_t::KM_NK_MN)
+else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+        layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
...