Commit e4e99a49 authored by Po-Yen, Chen

Use new utilities to shorten code

parent 7acbf104
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <initializer_list>
 #include <cstdlib>
+#include <initializer_list>
+#include <iostream>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_gemm.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/host_gemm.hpp"
+#include "ck/library/utility/literals.hpp"
 enum struct GemmMatrixLayout
 {
@@ -29,13 +29,19 @@ enum struct GemmMatrixLayout
 };
 template <typename T>
-static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
+static std::enable_if_t<std::is_convertible_v<T, double>, bool> check_out(const Tensor<T>& ref,
+                                                                          const Tensor<T>& out)
 {
-    float max_diff = 1e-6;
+    if(out.size() != ref.size())
+    {
+        return false;
+    }
+
+    constexpr float max_diff = 1e-6;

-    for(std::size_t i = 0; i < ref.mData.size(); ++i)
+    auto o = out.begin();
+    for(auto r = ref.begin(); r != ref.end(); ++r, ++o)
     {
-        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
+        const float diff = std::abs(double(*r) - double(*o));
         if(max_diff < diff)
         {
             return false;
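The rewritten check_out now rejects size mismatches up front, walks both tensors by iterator instead of indexing mData, and uses enable_if to restrict itself to element types that widen to double. A minimal self-contained sketch of the same pattern over plain std::vector (an illustrative stand-in, not CK's Tensor):

```cpp
#include <cmath>
#include <cstddef>
#include <type_traits>
#include <vector>

// Participate in overload resolution only when T can be widened to
// double, since the comparison below is done in double precision.
template <typename T>
std::enable_if_t<std::is_convertible_v<T, double>, bool>
check_out(const std::vector<T>& ref, const std::vector<T>& out)
{
    // A shape mismatch can never pass the element-wise check.
    if(out.size() != ref.size())
    {
        return false;
    }

    constexpr double max_diff = 1e-6;

    auto o = out.begin();
    for(auto r = ref.begin(); r != ref.end(); ++r, ++o)
    {
        if(max_diff < std::abs(double(*r) - double(*o)))
        {
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> ref{1.f, 2.f}, out{1.f, 2.f};
    return check_out(ref, out) ? 0 : 1;
}
```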
@@ -91,17 +97,17 @@ int test_gemm(const gemmArgs& args)
     default: printf("not supported layout"); return 1;
     }
+    using namespace ck::literals;

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
             if(row_major)
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
             }
             else
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
             }
         };
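The shortened HostTensorDescriptor calls lean on the `_uz` literal that `using namespace ck::literals;` brings into scope: with a plain `1`, a braced list such as `{stride, 1}` mixes std::size_t and int, so no single element type can be deduced for the initializer list. A sketch of how such a literal can be defined (the real definition lives in ck/library/utility/literals.hpp and may differ):

```cpp
#include <cstddef>
#include <initializer_list>

namespace literals {
// User-defined literal: 1_uz is a std::size_t, so {stride, 1_uz}
// deduces a homogeneous std::initializer_list<std::size_t>.
constexpr std::size_t operator""_uz(unsigned long long value)
{
    return static_cast<std::size_t>(value);
}
} // namespace literals

int main()
{
    using namespace literals;
    std::size_t stride = 8;
    auto dims = {stride, 1_uz}; // OK: both elements are std::size_t
    // auto bad = {stride, 1}; // error: conflicting deduced types
    (void)dims;
}
```

C++23 later standardized a built-in `uz` suffix (`1uz`) with the same meaning; a library literal like this fills that role on earlier standards.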
@@ -126,13 +132,13 @@ int test_gemm(const gemmArgs& args)
                                       ck::tensor_operation::element_wise::PassThrough{},
                                       ck::tensor_operation::element_wise::PassThrough{});

-    DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_m_k.GetMemorySize());
+    DeviceMem b_device_buf(b_k_n.GetMemorySize());
+    DeviceMem c_device_buf(c_m_n_device_result.GetMemorySize());

-    a_device_buf.ToDevice(a_m_k.mData.data());
-    b_device_buf.ToDevice(b_k_n.mData.data());
-    c_device_buf.ToDevice(c_m_n_device_result.mData.data());
+    a_device_buf.ToDevice(a_m_k.data());
+    b_device_buf.ToDevice(b_k_n.data());
+    c_device_buf.ToDevice(c_m_n_device_result.data());

     auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
         bool success = false;
@@ -174,7 +180,7 @@ int test_gemm(const gemmArgs& args)
         {
             invoker_ptr->Run(argument_ptr.get());

-            c_device_buf.FromDevice(c_m_n_device_result.mData.data());
+            c_device_buf.FromDevice(c_m_n_device_result.data());

             if(!check_out(c_m_n_host_result, c_m_n_device_result))
             {
...
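GetMemorySize() and data() fold the `sizeof(element) * element count` computation and the mData indirection into the tensor itself, so call sites stop restating the element type. A rough stand-in for the shape of these helpers (CK's real GetMemorySize sizes the element space described by the descriptor, which can exceed the raw element count for strided or padded layouts; this sketch ignores that):

```cpp
#include <cstddef>
#include <vector>

// Illustrative host-side buffer: the byte-size computation and the
// raw-pointer access live on the type instead of at every call site.
template <typename T>
struct HostBuffer
{
    std::vector<T> mData;

    std::size_t GetMemorySize() const { return sizeof(T) * mData.size(); }
    T* data() { return mData.data(); }
    const T* data() const { return mData.data(); }
};

int main()
{
    HostBuffer<float> a{std::vector<float>(16, 1.f)};
    // Previously: sizeof(float) * a.mData.size() at the call site.
    std::size_t bytes = a.GetMemorySize(); // 64 on typical platforms
    (void)bytes;
}
```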
@@ -3,18 +3,20 @@

 #pragma once

-#include <vector>
 #include <iostream>
+#include <vector>

 #include <gtest/gtest.h>

 #include "ck/ck.hpp"
 #include "ck/utility/number.hpp"
 #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
+
+#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/device_memory.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/ranges.hpp"
 namespace ck {
@@ -105,29 +107,31 @@ class TestLayernorm2d : public ::testing::Test
         gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
         beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

-        DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
-        DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
-        DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
-        DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+        DeviceMem x_dev(x.GetMemorySize());
+        DeviceMem gamma_dev(gamma.GetMemorySize());
+        DeviceMem beta_dev(beta.GetMemorySize());
+        DeviceMem y_dev(y.GetMemorySize());

-        x_dev.ToDevice(x.mData.data());
-        gamma_dev.ToDevice(gamma.mData.data());
-        beta_dev.ToDevice(beta.mData.data());
+        x_dev.ToDevice(x.data());
+        gamma_dev.ToDevice(gamma.data());
+        beta_dev.ToDevice(beta.data());
+
+        using Indices = std::vector<ck::index_t>;

         auto device_instance = DeviceInstance{};
-        auto argument_ptr = device_instance.MakeArgumentPointer(
-            lengths,
-            std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-            GammaStride,
-            BetaStride,
-            std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
-            reduceDims,
-            1e-4,
-            x_dev.GetDeviceBuffer(),
-            gamma_dev.GetDeviceBuffer(),
-            beta_dev.GetDeviceBuffer(),
-            y_dev.GetDeviceBuffer(),
-            PassThrough{});
+        auto argument_ptr =
+            device_instance.MakeArgumentPointer(lengths,
+                                                ck::ranges::to<Indices>(x.GetStrides()),
+                                                GammaStride,
+                                                BetaStride,
+                                                ck::ranges::to<Indices>(y.GetStrides()),
+                                                reduceDims,
+                                                1e-4,
+                                                x_dev.GetDeviceBuffer(),
+                                                gamma_dev.GetDeviceBuffer(),
+                                                beta_dev.GetDeviceBuffer(),
+                                                y_dev.GetDeviceBuffer(),
+                                                PassThrough{});

         if(!device_instance.IsSupportedArgument(argument_ptr.get()))
         {
@@ -140,19 +144,18 @@ class TestLayernorm2d : public ::testing::Test
         ref_instance_invoker_.Run(
             {x, gamma, beta, y_ref, PassThrough{}, lengths, reduceDims, 1e-4});

-        y_dev.FromDevice(y.mData.data());
+        y_dev.FromDevice(y.data());

         bool pass;
-        if(std::is_same<XDataType, int8_t>::value)
+        if constexpr(std::is_same_v<XDataType, int8_t>)
         {
-            EXPECT_TRUE(pass = ck::utils::check_err(
-                            y.mData, y_ref.mData, "Error: Incorrect results!", 0, 1));
+            EXPECT_TRUE(pass = ck::utils::check_err(y, y_ref, "Error: Incorrect results!", 0, 1));
         }
         else
         {
-            EXPECT_TRUE(pass = ck::utils::check_err(
-                            y.mData, y_ref.mData, "Error: Incorrect results d1", 1e-3, 1e-3));
+            EXPECT_TRUE(
+                pass = ck::utils::check_err(y, y_ref, "Error: Incorrect results d1", 1e-3, 1e-3));
         }

         if(!pass)
...
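ck::ranges::to<Indices>(x.GetStrides()) replaces the iterator-pair construction of the stride vectors, including the element conversion from std::size_t to ck::index_t. A minimal sketch of such a range-to-container converter, in the spirit of C++23's std::ranges::to (assumed semantics, not the exact utility from ck/library/utility/ranges.hpp):

```cpp
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

namespace ranges_sketch {
// Copy any iterable range into the requested container type,
// converting element types through the container's range constructor.
template <typename Container, typename Range>
Container to(const Range& range)
{
    return Container(std::begin(range), std::end(range));
}
} // namespace ranges_sketch

int main()
{
    std::vector<std::size_t> strides{256, 16, 1};
    using Indices = std::vector<std::int32_t>;
    // One call instead of Indices{strides.begin(), strides.end()}:
    auto idx = ranges_sketch::to<Indices>(strides);
    (void)idx;
}
```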
@@ -6,18 +6,19 @@
 #include <numeric>
 #include <type_traits>
 #include <vector>

 #include <gtest/gtest.h>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/fill.hpp"
 #include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
 namespace {
@@ -121,8 +122,8 @@ TEST(ReferenceConvolutionFWD, Conv2DGNHWC)
                                 490.5,
                                 508.5};
     EXPECT_TRUE(ck::utils::check_err(
-        out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"));
+        out_tensor.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor, ref_data, "Error: incorrect results!"));
 }
 TEST(ReferenceConvolutionFWD, Conv2DGNHWCStridesDilationsPadding)
@@ -140,7 +141,7 @@ TEST(ReferenceConvolutionFWD, Conv2DGNHWCStridesDilationsPadding)
                                  std::vector<ck::index_t>{1, 1});

     auto out_tensor = run_reference_convolution_forward<2>(conv_param);
-    std::vector<std::size_t> ref_dims = std::vector<std::size_t>{1, 5, 5, 2};
+    std::vector<std::size_t> ref_dims{1, 5, 5, 2};
     std::vector<float> ref_data{
         210., 210., 327., 327., 351., 351., 375., 375., 399., 399.,
         459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5,
@@ -148,8 +149,8 @@ TEST(ReferenceConvolutionFWD, Conv2DGNHWCStridesDilationsPadding)
         1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5,
         1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5};
     EXPECT_TRUE(ck::utils::check_err(
-        out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"));
+        out_tensor.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor, ref_data, "Error: incorrect results!"));
 }
 TEST(ReferenceConvolutionFWD, Conv1DGNWC)
@@ -177,8 +178,8 @@ TEST(ReferenceConvolutionFWD, Conv1DGNWC)
     std::vector<std::size_t> ref_dims{1, 1, 4, 1};
     std::vector<float> ref_data{7.5, 13.5, 19.5, 25.5};
     EXPECT_TRUE(ck::utils::check_err(
-        out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"));
+        out_tensor.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor, ref_data, "Error: incorrect results!"));
 }
 TEST(ReferenceConvolutionFWD, Conv1DGNWCStridesDilationsPadding)
@@ -206,8 +207,8 @@ TEST(ReferenceConvolutionFWD, Conv1DGNWCStridesDilationsPadding)
     std::vector<std::size_t> ref_dims{1, 1, 5, 2};
     std::vector<float> ref_data{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5};
     EXPECT_TRUE(ck::utils::check_err(
-        out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"));
+        out_tensor.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor, ref_data, "Error: incorrect results!"));
 }
 TEST(ReferenceConvolutionFWD, Conv1DGNWCSameOutputSize)
@@ -300,8 +301,8 @@ TEST(ReferenceConvolutionFWD, Conv1DGNWCSameOutputSize)
         49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4,
         49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4};
     EXPECT_TRUE(ck::utils::check_err(
-        out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"));
+        out_tensor2.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor2, ref_data, "Error: incorrect results!"));
 }
 #endif
@@ -337,11 +338,9 @@ TEST(ReferenceConvolutionFWD, Conv3DGNCDHW)
         634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8,
         699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004,
         731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.};
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(),
-                                     ref_dims,
-                                     "Error [case 1]: wrong output tensor dimensions!"));
-    EXPECT_TRUE(
-        ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"));
+    EXPECT_TRUE(ck::utils::check_err(
+        out_tensor.GetLengths(), ref_dims, "Error [case 1]: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(out_tensor, ref_data, "Error [case 1]: incorrect results!"));
 }
 TEST(ReferenceConvolutionFWD, Conv3DGNCDHWStridesDilations)
@@ -384,9 +383,8 @@ TEST(ReferenceConvolutionFWD, Conv3DGNCDHWStridesDilations)
         5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004,
         6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4,
         6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801};
-    EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(),
-                                     ref_dims,
-                                     "Error [case 2]: wrong output tensor dimensions!"));
-    EXPECT_TRUE(ck::utils::check_err(
-        out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f));
+    EXPECT_TRUE(ck::utils::check_err(
+        out_tensor.GetLengths(), ref_dims, "Error [case 2]: wrong output tensor dimensions!"));
+    EXPECT_TRUE(ck::utils::check_err(
+        out_tensor, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f));
 }
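These call sites now hand whole tensors to ck::utils::check_err instead of their mData vectors, which presumes overloads that accept anything range-like along with relative and absolute tolerances. A standalone sketch of that signature style (the name, the tolerance defaults, and the error formula below are assumptions, not CK's API):

```cpp
#include <cmath>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Range-style tolerance check: fail when |out - ref| exceeds
// atol + rtol * |ref| for any element, or when sizes differ.
template <typename Range>
bool check_err_sketch(const Range& out,
                      const Range& ref,
                      const std::string& msg,
                      double rtol = 1e-5,
                      double atol = 3e-6)
{
    if(std::size(out) != std::size(ref))
    {
        std::cerr << msg << " (size mismatch)\n";
        return false;
    }
    auto r = std::begin(ref);
    for(auto o = std::begin(out); o != std::end(out); ++o, ++r)
    {
        if(std::abs(double(*o) - double(*r)) > atol + rtol * std::abs(double(*r)))
        {
            std::cerr << msg << '\n';
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> out{1.f, 2.f}, ref{1.f, 2.f + 1e-7f};
    return check_err_sketch(out, ref, "Error: incorrect results!") ? 0 : 1;
}
```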
@@ -3,19 +3,20 @@

 #pragma once

-#include <vector>
 #include <iostream>
+#include <vector>

 #include <gtest/gtest.h>

 #include "ck/ck.hpp"
-#include "ck/utility/number.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/utility/number.hpp"
+
+#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/device_memory.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 namespace ck {
@@ -85,15 +86,13 @@ class TestSoftmax : public ::testing::Test
         Tensor<OutDataType> out_ref(out);

-        DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
-        DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
+        DeviceMem in_dev(in.GetMemorySize());
+        DeviceMem out_dev(out.GetMemorySize());

-        in_dev.ToDevice(in.mData.data());
-        out_dev.ToDevice(out.mData.data());
+        in_dev.ToDevice(in.data());
+        out_dev.ToDevice(out.data());

-        std::vector<index_t> i_in_lengths(in.mDesc.GetLengths().begin(),
-                                          in.mDesc.GetLengths().end());
-        std::vector<index_t> i_in_strides(in.mDesc.GetStrides().begin(),
-                                          in.mDesc.GetStrides().end());
+        std::vector<index_t> i_in_lengths(in.GetLengths().begin(), in.GetLengths().end());
+        std::vector<index_t> i_in_strides(in.GetStrides().begin(), in.GetStrides().end());

         auto device_instance = DeviceInstance{};
         auto argument_ptr = device_instance.MakeArgumentPointer(i_in_lengths,
@@ -119,18 +118,18 @@ class TestSoftmax : public ::testing::Test
         ref_instance_invoker_.Run({in, out_ref, alpha, beta, reduce_dims});

-        out_dev.FromDevice(out.mData.data());
+        out_dev.FromDevice(out.data());

         bool pass;
-        if(std::is_same<InDataType, int8_t>::value)
+        if constexpr(std::is_same_v<InDataType, int8_t>)
         {
-            EXPECT_TRUE(pass = ck::utils::check_err(
-                            out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1));
+            EXPECT_TRUE(pass =
+                            ck::utils::check_err(out, out_ref, "Error: Incorrect results!", 0, 1));
         }
         else
         {
-            EXPECT_TRUE(pass = ck::utils::check_err(out.mData, out_ref.mData));
+            EXPECT_TRUE(pass = ck::utils::check_err(out, out_ref));
         }

         if(!pass)
...
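Replacing the runtime `if` with `if constexpr` turns the int8 dispatch into a compile-time choice: the untaken branch is discarded at instantiation time, so each InDataType only has to compile the check_err call it actually uses. A small illustration of the difference:

```cpp
#include <cstdint>
#include <iostream>
#include <type_traits>

// With a plain `if`, both branches must compile for every T.
// With `if constexpr`, the untaken branch is discarded, so
// type-specific code is safe inside it.
template <typename T>
void report(T value)
{
    if constexpr(std::is_same_v<T, std::int8_t>)
    {
        // Integer data: printed (and typically checked) with
        // absolute thresholds.
        std::cout << "int8 path: " << int(value) << '\n';
    }
    else
    {
        // Floating-point data: relative tolerances apply.
        std::cout << "float path: " << value << '\n';
    }
}

int main()
{
    report(std::int8_t{3});
    report(2.5f);
}
```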