Commit df78aadf authored by wsttiger

merged from master

parents ba934fc2 58681660
@@ -16,8 +16,8 @@ void dead_code_elimination::apply(program& p) const
         if(ins == p.begin())
             continue;
         const auto i = std::prev(ins);
-        // Skip instruction with empty shape as output
-        if(i->result.elements() == 0)
+        // Skip instruction with empty shape as output unless its a builtin
+        if(i->result.elements() == 0 and not(i->op.name().front() == '@'))
             continue;
         // Skip the last instruction
         if(i == last)
...
@@ -2,7 +2,7 @@
 namespace migraph {

-argument generate_argument(shape s, std::mt19937::result_type seed)
+argument generate_argument(shape s, unsigned long seed)
 {
     argument result;
     s.visit_type([&](auto as) {
@@ -13,7 +13,7 @@ argument generate_argument(shape s, unsigned long seed)
     return result;
 }

-literal generate_literal(shape s, std::mt19937::result_type seed)
+literal generate_literal(shape s, unsigned long seed)
 {
     literal result;
     s.visit_type([&](auto as) {
@@ -24,4 +24,10 @@ literal generate_literal(shape s, unsigned long seed)
     return result;
 }

+// TODO: Move to literal.cpp
+literal abs(literal l)
+{
+    return transform(std::move(l), [](auto x) { return std::fabs(x); });
+}
+
 } // namespace migraph
@@ -97,8 +97,8 @@ struct check_shapes

     const check_shapes& not_broadcasted() const
     {
-        // if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
-        //     MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
+        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
+            MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
         return *this;
     }
...
@@ -17,19 +17,21 @@ namespace migraph {
 /// during `eval`.
 struct context
 {
+    /// Wait for any tasks in the context to complete
+    void finish() const;
 };

 #else

 /*
  * Type-erased interface for:
  *
  * struct context
  * {
+ *     void finish() const;
  * };
  *
  */

 struct context
 {
@@ -88,12 +90,20 @@ struct context
         return private_detail_te_get_handle().type();
     }

+    void finish() const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().finish();
+    }
+
     private:
     struct private_detail_te_handle_base_type
     {
         virtual ~private_detail_te_handle_base_type() {}
         virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
         virtual const std::type_info& type() const                                = 0;
+        virtual void finish() const                                               = 0;
     };

     template <typename PrivateDetailTypeErasedT>
@@ -124,6 +134,8 @@ struct context
         const std::type_info& type() const override { return typeid(private_detail_te_value); }

+        void finish() const override { return private_detail_te_value.finish(); }
+
         PrivateDetailTypeErasedT private_detail_te_value;
     };
...
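For reference, any concrete context stored in the type-erased wrapper now has to expose a const finish() member. A minimal sketch (the my_context type, its message, and the include path are illustrative assumptions, not part of the commit):

#include <iostream>
#include <migraph/context.hpp> // assumed header path for the type-erased context

// Hypothetical backend context; finish() would block until queued work is done.
struct my_context
{
    void finish() const { std::cout << "queue drained" << std::endl; }
};

int main()
{
    migraph::context ctx = my_context{}; // stored through the type-erased handle
    ctx.finish();                        // dispatches to my_context::finish()
}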
@@ -7,13 +7,37 @@

 namespace migraph {

+template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
+constexpr T normalize(unsigned long z)
+{
+    if(z == 0)
+        return 0;
+    return (2.0 / z) - 1.0;
+}
+
+template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
+constexpr T normalize(unsigned long z)
+{
+    const auto max      = std::numeric_limits<T>::max();
+    const auto half_max = max / 2;
+    return half_max - (z % max);
+}
+
+template <class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})>
+constexpr T normalize(unsigned long z)
+{
+    const auto max = std::numeric_limits<T>::max();
+    return z % max;
+}
+
 template <class T>
 struct xorshf96_generator
 {
-    unsigned long max = 31;
     unsigned long x = 123456789;
     unsigned long y = 362436069;
-    unsigned long z = 521288629;
+    unsigned long z;
+
+    xorshf96_generator(unsigned long seed = 0) : z(521288629ULL ^ seed) {}

     constexpr T operator()() noexcept
     {
@@ -26,21 +50,23 @@ struct xorshf96_generator
         y = z;
         z = t ^ x ^ y;

-        return z % max;
+        return normalize<T>(z);
     }
 };

 template <class T>
-std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type)
+std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
 {
     std::vector<T> result(s.elements());
-    std::generate(result.begin(), result.end(), xorshf96_generator<T>{});
+    std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
     return result;
 }

-argument generate_argument(shape s, std::mt19937::result_type seed = 0);
+argument generate_argument(shape s, unsigned long seed = 0);
+
+literal generate_literal(shape s, unsigned long seed = 0);

-literal generate_literal(shape s, std::mt19937::result_type seed = 0);
+literal abs(literal l);

 } // namespace migraph
...
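A short sketch of how the seeded generator behaves for different element types; the header path is an assumption and the comments paraphrase the normalize overloads above rather than exact output values:

#include <cstdint>
#include <iostream>
#include <migraph/generate.hpp> // assumed path for the header above

int main()
{
    // Same seed -> same deterministic sequence; the seed is xor-ed into z.
    migraph::xorshf96_generator<float> fgen{42};
    std::cout << fgen() << std::endl; // normalize<float>: (2.0 / z) - 1.0

    migraph::xorshf96_generator<std::int8_t> igen{42};
    std::cout << int(igen()) << std::endl; // signed: half_max - (z % max), roughly in [-63, 63]

    migraph::xorshf96_generator<unsigned> ugen{42};
    std::cout << ugen() << std::endl; // unsigned: z % max
}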
@@ -94,6 +94,19 @@ struct literal : raw_data<literal>
     }
 };

+template <class F>
+literal transform(literal l, F f)
+{
+    literal result;
+    l.visit([&](auto x) {
+        using type = std::remove_cv_t<typename decltype(x)::value_type>;
+        std::vector<type> output(x.size(), 0.0);
+        std::transform(x.begin(), x.end(), output.begin(), f);
+        result = literal{l.get_shape(), output};
+    });
+    return result;
+}
+
 } // namespace migraph

 #endif
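Sketch of the new transform and abs helpers in use; the float_type enumerator and the header paths are assumptions here, while the shape-plus-vector literal constructor is the one the transform implementation itself relies on:

#include <vector>
#include <migraph/literal.hpp>  // assumed path
#include <migraph/generate.hpp> // declares abs(literal) per this commit

int main()
{
    migraph::shape s{migraph::shape::float_type, {3}}; // assumed type enumerator
    migraph::literal l{s, std::vector<float>{-1.0f, 2.0f, -3.0f}};

    migraph::literal a = migraph::abs(l); // element-wise std::fabs via transform
    migraph::literal d = migraph::transform(l, [](auto x) { return x * 2; });
    (void)a;
    (void)d;
}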
@@ -133,7 +133,7 @@ struct convolution

 struct pooling
 {
-    std::string mode;
+    std::string mode                   = "average";
     std::array<std::size_t, 2> padding = {{0, 0}};
     std::array<std::size_t, 2> stride  = {{1, 1}};
     std::array<std::size_t, 2> lengths = {{1, 1}};
...
@@ -8,6 +8,7 @@
 #include <migraph/builtin.hpp>
 #include <migraph/instruction_ref.hpp>
 #include <migraph/target.hpp>
+#include <migraph/tracer.hpp>
 #include <algorithm>
 #include <iostream>
@@ -88,7 +89,7 @@ struct program

     instruction_ref validate() const;

-    void compile(const target& t);
+    void compile(const target& t, tracer trace = tracer{});

     void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
...
@@ -13,12 +13,35 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;

+template <int N>
+struct requires_enum
+{
+    enum e
+    {
+        a = 0
+    };
+};
+
+#define MIGRAPH_REQUIRES_CAT(x, y) x##y
+
 #ifdef CPPCHECK
 #define MIGRAPH_REQUIRES(...) class = void
 #else
+#if 0
+// TODO: This currently crashed on clang
 #define MIGRAPH_REQUIRES(...) \
-    bool PrivateRequires##__LINE__ = true, \
-    class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
+    typename migraph::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
+        PrivateRequires, \
+        __LINE__) = migraph::requires_enum<__LINE__>::a, \
+    class = typename std::enable_if<and_<__VA_ARGS__, \
+                                         MIGRAPH_REQUIRES_CAT(PrivateRequires, __LINE__) == \
+                                             migraph::requires_enum<__LINE__>::a>{}>::type
+#else
+#define MIGRAPH_REQUIRES(...) \
+    typename migraph::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
+        PrivateRequires, __LINE__) = migraph::requires_enum<__LINE__>::a, \
+    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
+#endif
 #endif

 } // namespace migraph
...
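The macro is used exactly as before; the sketch below shows the pattern it supports, two constrained overloads (hypothetical halve functions) whose per-line requires_enum default parameters keep them from colliding. The include path is an assumption; the macro expands to an unqualified and_<...>, so it is used inside namespace migraph as the rest of the codebase does:

#include <type_traits>
#include <migraph/type_traits.hpp> // assumed header path for MIGRAPH_REQUIRES

namespace migraph {

template <class T, MIGRAPH_REQUIRES(std::is_integral<T>{})>
T halve(T x)
{
    return x / 2; // integer division
}

template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
T halve(T x)
{
    return x / T{2}; // exact halving for floating point
}

} // namespace migraph

// migraph::halve(7) == 3, migraph::halve(7.0) == 3.5; only one overload participates per call.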
@@ -94,13 +94,13 @@ struct tensor_view
     // TODO: Add iterators so it can handle nonstandard tensors
     T* begin()
     {
-        assert(this->m_shape.standard());
+        assert(this->m_shape.standard() or this->empty());
         return m_data;
     }

     T* end()
     {
-        assert(this->m_shape.standard());
+        assert(this->m_shape.standard() or this->empty());
         if(this->empty())
             return m_data;
         else
@@ -109,13 +109,13 @@ struct tensor_view

     const T* begin() const
     {
-        assert(this->m_shape.standard());
+        assert(this->m_shape.standard() or this->empty());
         return m_data;
     }

     const T* end() const
     {
-        assert(this->m_shape.standard());
+        assert(this->m_shape.standard() or this->empty());
         if(this->empty())
             return m_data;
         else
...
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP

#include <ostream>

namespace migraph {

struct swallow
{
    template <class... Ts>
    swallow(Ts&&...)
    {
    }
};

struct tracer
{
    tracer() {}

    tracer(std::ostream& s) : os(&s) {}

    bool enabled() const { return os != nullptr; }

    template <class... Ts>
    void operator()(const Ts&... xs) const
    {
        if(os != nullptr)
        {
            swallow{*os << xs...};
            *os << std::endl;
        }
    }

    private:
    std::ostream* os = nullptr;
};

} // namespace migraph

#endif
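A short sketch of the tracer in use; program and target construction are elided, and the commented compile() call is only how the new overload would be invoked:

#include <iostream>
#include <migraph/tracer.hpp>

int main()
{
    migraph::tracer t{std::cout};
    t("Pass: ", "auto_contiguous"); // streams each argument, then std::endl

    migraph::tracer off;  // default-constructed: enabled() is false
    off("never printed"); // operator() is a no-op without a stream

    // p.compile(some_target, migraph::tracer{std::cout}); // how a compile would be traced
}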
@@ -7,6 +7,8 @@
 #include <iostream>
 #include <numeric>

+#include <migraph/float_equal.hpp>
+
 namespace migraph {

 // Compute the value of a range
@@ -101,7 +103,7 @@ auto range_distance(R1&& r1)
 template <class R1>
 bool range_zero(R1&& r1)
 {
-    return std::all_of(r1.begin(), r1.end(), [](auto x) { return x == 0; });
+    return std::all_of(r1.begin(), r1.end(), [](auto x) { return float_equal(x, 0); });
 }

 template <class R1, class R2, class T, class Reducer, class Product>
...
@@ -41,17 +41,22 @@ int main(int argc, char const* argv[])
     if(argc > 1)
     {
         std::string file = argv[1];
+        auto p           = migraph::parse_onnx(file);
+        std::cout << p << std::endl;
+
         auto x = run_cpu(file);
         auto y = run_gpu(file);
         visit_all(x, y)([](auto cpu, auto gpu) {
-            if(migraph::verify_range(cpu, gpu))
+            if(migraph::verify_range(cpu, gpu, 100))
             {
                 std::cout << "Passed" << std::endl;
             }
             else
             {
                 std::cout << "Not equal" << std::endl;
+                std::cout << "cpu:" << std::endl;
                 std::cout << cpu << std::endl;
+                std::cout << "gpu:" << std::endl;
                 std::cout << gpu << std::endl;
             }
...
@@ -237,23 +237,21 @@ instruction_ref program::validate() const
         [&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
 }

-void program::compile(const target& t)
+void program::compile(const target& t, tracer trace)
 {
     assert(this->validate() == impl->instructions.end());
     this->impl->ctx = t.get_context();
-    if(enabled(MIGRAPH_TRACE_COMPILE{}))
-        std::cout << *this << std::endl << std::endl;
-    ;
+    if(not trace.enabled() and enabled(MIGRAPH_TRACE_COMPILE{}))
+        trace = tracer{std::cout};
+    trace(*this);
+    trace();
     for(auto&& p : t.get_passes(this->impl->ctx))
     {
-        if(enabled(MIGRAPH_TRACE_COMPILE{}))
-            std::cout << "Pass: " << p.name() << std::endl;
+        trace("Pass: ", p.name());
         p.apply(*this);
-        if(enabled(MIGRAPH_TRACE_COMPILE{}))
-            std::cout << *this << std::endl;
+        trace(*this);
 #ifndef NDEBUG
-        if(enabled(MIGRAPH_TRACE_COMPILE{}))
-            std::cout << "Validate ..." << std::endl;
+        trace("Validate ...");
         auto invalid = this->validate();
         if(invalid != impl->instructions.end())
         {
@@ -261,8 +259,7 @@ void program::compile(const target& t)
             MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
                           std::to_string(index) + ": " + invalid->op.name());
         }
-        if(enabled(MIGRAPH_TRACE_COMPILE{}))
-            std::cout << std::endl;
+        trace();
 #endif
     }
     auto invalid = this->validate();
@@ -334,28 +331,36 @@ double common_average(const std::vector<double>& v)
 void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
 {
     using milliseconds = std::chrono::duration<double, std::milli>;
+    auto& ctx          = this->impl->ctx;
     // Run once by itself
     eval(params);
+    ctx.finish();
     // Run and time entire program
     std::vector<double> total_vec;
     total_vec.reserve(n);
     for(std::size_t i = 0; i < n; i++)
     {
-        total_vec.push_back(time<milliseconds>([&] { eval(params); }));
+        total_vec.push_back(time<milliseconds>([&] {
+            eval(params);
+            ctx.finish();
+        }));
     }
     std::sort(total_vec.begin(), total_vec.end());
     std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
     // Fill the map
-    generic_eval(*this, this->impl->ctx, params, [&](auto ins, auto) {
+    generic_eval(*this, ctx, params, [&](auto ins, auto) {
         ins_vec[ins].reserve(n);
         return argument{};
     });
     // Run and time each instruction
     for(std::size_t i = 0; i < n; i++)
     {
-        generic_eval(*this, this->impl->ctx, params, [&](auto ins, auto f) {
+        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
             argument result;
-            ins_vec[ins].push_back(time<milliseconds>([&] { result = f(); }));
+            ins_vec[ins].push_back(time<milliseconds>([&] {
+                result = f();
+                ctx.finish();
+            }));
             return result;
         });
     }
@@ -366,9 +371,8 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
     overhead_vec.reserve(n);
     for(std::size_t i = 0; i < n; i++)
     {
-        overhead_vec.push_back(time<milliseconds>([&] {
-            generic_eval(*this, this->impl->ctx, params, [](auto...) { return argument{}; });
-        }));
+        overhead_vec.push_back(time<milliseconds>(
+            [&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
     }
     double total_time = common_average(total_vec);
@@ -376,13 +380,33 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
     double overhead_time          = common_average(overhead_vec);
     double overhead_percent       = overhead_time * 100.0 / total_time;
     double total_instruction_time = 0.0;
+    std::unordered_map<std::string, double> op_times;
     for(auto&& p : ins_vec)
-        total_instruction_time += common_average(p.second);
+    {
+        double avg = common_average(p.second);
+        op_times[p.first->op.name()] += avg;
+        total_instruction_time += avg;
+    }
     double calculate_overhead_time    = total_time - total_instruction_time;
     double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;

-    print_program(
-        os, *this, [&](auto ins, auto&&) { os << ": " << common_average(ins_vec[ins]) << "ms"; });
+    print_program(os, *this, [&](auto ins, auto&&) {
+        double avg     = common_average(ins_vec[ins]);
+        double percent = std::ceil(100.0 * avg / total_instruction_time);
+        os << ": " << avg << "ms, " << percent << "%";
+    });
+    os << std::endl;
+    os << "Summary:" << std::endl;
+    for(auto&& p : op_times)
+    {
+        auto&& name    = p.first;
+        double avg     = p.second;
+        double percent = std::ceil(100.0 * avg / total_instruction_time);
+        os << name << ": " << avg << "ms, " << percent << "%" << std::endl;
+    }
+    os << std::endl;
     os << "Rate: " << rate << "/sec" << std::endl;
     os << "Total time: " << total_time << "ms" << std::endl;
...
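perf_report now calls ctx.finish() inside every timed lambda so asynchronous (e.g. GPU) work is counted in the measurement rather than just the launch overhead. A minimal sketch of that pattern, assuming a chrono-based time helper like the stand-in below (the real helper lives elsewhere in the codebase and may differ):

#include <chrono>

// Hypothetical stand-in for the time<Duration>(f) helper used above.
template <class Duration, class F>
double time(F f)
{
    auto start = std::chrono::steady_clock::now();
    f();
    auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<Duration>(stop - start).count();
}

// Usage mirroring perf_report:
//   double ms = time<std::chrono::duration<double, std::milli>>([&] {
//       eval(params);
//       ctx.finish(); // include the wait, otherwise only launch time is measured
//   });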
@@ -65,6 +65,7 @@ struct cpu_batch_norm_inference
                 dfor(num_batch, num_channels, image_height, image_width)(
                     [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
+                        assert((variance(c) + epsilon) > 0);
                         result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
                                                  std::sqrt(variance(c) + epsilon) +
                                              bias(c);
@@ -79,6 +80,7 @@ struct cpu_batch_norm_inference
                 dfor(num_batch, num_channels, image_height, image_width)(
                     [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
+                        assert((variance(c, h, w) + epsilon) > 0);
                         result(n, c, h, w) = gamma(c, h, w) *
                                                  (buffer(n, c, h, w) - mean(c, h, w)) /
                                                  std::sqrt(variance(c, h, w) + epsilon) +
@@ -212,6 +214,7 @@ struct cpu_contiguous
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
+        assert(output_shape.standard());
         argument result{output_shape};
         visit_all(result, args[0])([&](auto output, auto input) {
             shape_for_each(output.get_shape(), [&](const auto& idx) {
...
@@ -8,7 +8,7 @@ namespace cpu {

 std::string cpu_target::name() const { return "cpu"; }

-std::vector<pass> cpu_target::get_passes(context&) const
+std::vector<pass> cpu_target::get_passes(migraph::context&) const
 {
     return {auto_contiguous{}, cpu_lowering{}};
 }
...
#ifndef MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP

namespace migraph {
namespace cpu {

struct context
{
    void finish() const {}
};

} // namespace cpu
} // namespace migraph

#endif
@@ -2,6 +2,7 @@
 #define MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP

 #include <migraph/program.hpp>
+#include <migraph/cpu/context.hpp>

 namespace migraph {
 namespace cpu {
@@ -9,8 +10,8 @@ namespace cpu {
 struct cpu_target
 {
     std::string name() const;
-    std::vector<pass> get_passes(context& ctx) const;
-    context get_context() const { return {}; }
+    std::vector<pass> get_passes(migraph::context& ctx) const;
+    migraph::context get_context() const { return context{}; }
 };

 } // namespace cpu
...
@@ -18,6 +18,7 @@ target_link_libraries(migraph_device migraph hip::device)
 target_include_directories(migraph_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)

 add_library(migraph_gpu
+    eliminate_allocation.cpp
     eliminate_workspace.cpp
     hip.cpp
     target.cpp
...
#include <migraph/gpu/eliminate_allocation.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>

namespace migraph {
namespace gpu {

void eliminate_allocation::apply(program& p) const
{
    std::size_t n = 0;
    std::vector<std::pair<instruction_ref, std::size_t>> allocs;
    for(auto ins : iterator_for(p))
    {
        if(ins->op.name() != "hip::allocate")
            continue;
        allocs.emplace_back(ins, n);
        std::size_t size = ins->get_shape().bytes();
        n += size + (size % 4);
    }
    auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
    for(auto&& pp : allocs)
    {
        auto ins    = pp.first;
        auto s      = ins->get_shape();
        auto offset = pp.second;
        p.replace_instruction(ins, hip_load{s, offset}, mem);
    }
}

} // namespace gpu
} // namespace migraph
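A worked example of the offset bookkeeping above (standalone, not part of the pass): each hip::allocate is assigned the current offset into the shared "memory" parameter, and the running total advances by size + (size % 4).

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::size_t n = 0;
    for(std::size_t size : std::vector<std::size_t>{10, 16, 7})
    {
        std::cout << "offset " << n << " for " << size << " bytes" << std::endl;
        n += size + (size % 4); // 10 -> +12, 16 -> +16, 7 -> +10
    }
    std::cout << "total buffer: " << n << " bytes" << std::endl; // offsets 0, 12, 28; total 38
}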