Commit e3976f16 authored by rocking's avatar rocking
Browse files

Change to use check_err()

parent aa027054
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
......@@ -63,7 +64,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_gemm_bias_add_reduce_impl(int do_verification,
void profile_gemm_bias_add_reduce_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
......@@ -75,8 +76,6 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
int StrideC,
int StrideC1)
{
bool pass = true;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
......@@ -353,13 +352,9 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
float c_error = check_error(c_m_n_host_result, c_m_n_device_result);
float d0_error = check_error(d0_m_host_result, d0_m_device_result);
float d1_error = check_error(d1_m_host_result, d1_m_device_result);
pass = pass && (c_error < 1E-6);
pass = pass && (d0_error < 1E-6);
pass = pass && (d1_error < 1E-6);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);
if(do_log)
{
......@@ -388,8 +383,6 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
}
} // namespace profiler
......
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
......@@ -312,13 +313,9 @@ bool profile_gemm_reduce_impl(int do_verification,
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
float c_error = check_error(c_m_n_host_result, c_m_n_device_result);
float d0_error = check_error(d0_m_host_result, d0_m_device_result);
float d1_error = check_error(d1_m_host_result, d1_m_device_result);
pass = pass && (c_error < 1E-6);
pass = pass && (d0_error < 1E-6);
pass = pass && (d1_error < 1E-6);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);
if(do_log)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment