Commit 3e141c95 authored by Paul's avatar Paul
Browse files

Add finish method to context

parent 74b3d019
......@@ -17,19 +17,21 @@ namespace migraph {
/// during `eval`.
struct context
{
/// Wait for any tasks in the context to complete
void finish() const;
};
#else
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
* Type-erased interface for:
*
* struct context
* {
* void finish() const;
* };
*
*/
struct context
{
......@@ -88,12 +90,20 @@ struct context
return private_detail_te_get_handle().type();
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().finish();
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void finish() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -124,6 +134,8 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() const override { return private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -7,6 +7,8 @@
#include <iostream>
#include <numeric>
#include <migraph/float_equal.hpp>
namespace migraph {
// Compute the value of a range
......@@ -101,7 +103,7 @@ auto range_distance(R1&& r1)
template <class R1>
bool range_zero(R1&& r1)
{
return std::all_of(r1.begin(), r1.end(), [](auto x) { return x == 0; });
return std::all_of(r1.begin(), r1.end(), [](auto x) { return float_equal(x, 0); });
}
template <class R1, class R2, class T, class Reducer, class Product>
......
......@@ -331,28 +331,30 @@ double common_average(const std::vector<double>& v)
void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto& ctx = this->impl->ctx;
// Run once by itself
eval(params);
ctx.finish();
// Run and time entire program
std::vector<double> total_vec;
total_vec.reserve(n);
for(std::size_t i = 0; i < n; i++)
{
total_vec.push_back(time<milliseconds>([&] { eval(params); }));
total_vec.push_back(time<milliseconds>([&] { eval(params); ctx.finish(); }));
}
std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map
generic_eval(*this, this->impl->ctx, params, [&](auto ins, auto) {
generic_eval(*this, ctx, params, [&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{};
});
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
generic_eval(*this, this->impl->ctx, params, [&](auto ins, auto f) {
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result;
ins_vec[ins].push_back(time<milliseconds>([&] { result = f(); }));
ins_vec[ins].push_back(time<milliseconds>([&] { result = f(); ctx.finish(); }));
return result;
});
}
......@@ -364,7 +366,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
for(std::size_t i = 0; i < n; i++)
{
overhead_vec.push_back(time<milliseconds>([&] {
generic_eval(*this, this->impl->ctx, params, [](auto...) { return argument{}; });
generic_eval(*this, ctx, params, [](auto...) { return argument{}; });
}));
}
......
......@@ -8,7 +8,7 @@ namespace cpu {
std::string cpu_target::name() const { return "cpu"; }
std::vector<pass> cpu_target::get_passes(context&) const
std::vector<pass> cpu_target::get_passes(migraph::context&) const
{
return {auto_contiguous{}, cpu_lowering{}};
}
......
#ifndef MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
namespace migraph {
namespace cpu {
struct context
{
void finish() const
{}
};
} // namespace cpu
} // namespace migraph
#endif
......@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#include <migraph/program.hpp>
#include <migraph/cpu/context.hpp>
namespace migraph {
namespace cpu {
......@@ -9,8 +10,8 @@ namespace cpu {
struct cpu_target
{
std::string name() const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const { return {}; }
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context() const { return context{}; }
};
} // namespace cpu
......
......@@ -3,6 +3,7 @@
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/hip.hpp>
namespace migraph {
namespace gpu {
......@@ -11,6 +12,10 @@ struct context
{
shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
void finish() const
{
gpu_sync();
}
};
} // namespace gpu
......
......@@ -332,7 +332,6 @@ struct miopen_apply
check_shape(s, apply_batch_norm_inference(it));
}
}
prog->insert_instruction(prog->end(), hip_sync{}, std::prev(prog->end()));
}
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
......
......@@ -127,6 +127,29 @@ void verify_program()
{
// TODO: Check for nans
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
// std::cout << cpu << std::endl;
// std::cout << gpu << std::endl;
if(migraph::range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(migraph::range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
auto idx = migraph::mismatch_idx(cpu, gpu, migraph::float_equal);
if(idx < migraph::range_distance(cpu))
{
std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
<< std::endl;
}
auto cpu_nan_idx = find_idx(cpu, migraph::not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, migraph::not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
}
});
std::set_terminate(nullptr);
......
......@@ -17,12 +17,16 @@ namespace migraph {
/// during `eval`.
struct context
{
/// Wait for any tasks in the context to complete
void finish() const;
};
#else
<%
interface('context')
interface('context',
virtual('finish', returns='void', const=True)
)
%>
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment