Commit fa4cf4c4 authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pad_op

parents e067dce0 31b2c735
@@ -15,11 +15,19 @@ struct check_context
     std::string name() const { return "check_context"; }
     shape compute_shape(const std::vector<shape>&) const { return {}; }
     argument compute(context& ctx, const shape&, const std::vector<argument>&) const
+    {
+        this->check(ctx);
+        return {};
+    }
+    void finalize(context& ctx, const shape&, const std::vector<shape>&) const
+    {
+        this->check(ctx);
+    }
+    void check(context& ctx) const
     {
         T* x = any_cast<T>(&ctx);
         if(x == nullptr)
             MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
-        return {};
     }
 };
...
@@ -119,6 +119,13 @@ struct concat_optimization
         return (*this).private_detail_te_get_handle().get_concat(op);
     }

+    friend bool is_shared(const concat_optimization& private_detail_x,
+                          const concat_optimization& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
+    }
+
     private:
     struct private_detail_te_handle_base_type
     {
...
@@ -95,7 +95,13 @@ struct context
     void finish() const
     {
         assert((*this).private_detail_te_handle_mem_var);
-        return (*this).private_detail_te_get_handle().finish();
+        (*this).private_detail_te_get_handle().finish();
+    }
+
+    friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
     }

     private:
@@ -136,7 +142,7 @@ struct context
         const std::type_info& type() const override { return typeid(private_detail_te_value); }

-        void finish() const override { return private_detail_te_value.finish(); }
+        void finish() const override { private_detail_te_value.finish(); }

         PrivateDetailTypeErasedT private_detail_te_value;
     };
...
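A note on the new `is_shared` friends that this merge adds to every type-erased wrapper (context, pass, target, operation, concat_optimization): they compare the shared handle inside the wrappers, not the wrapped values, so they answer "do these two objects alias the same underlying state?". A minimal sketch of the intended semantics, assuming copies of a wrapper share the handle the way a copied `std::shared_ptr` member does (`my_context` and the header path are illustrative, not part of this diff):

#include <cassert>
#include <migraphx/context.hpp> // assumed header for the type-erased context

// Hypothetical context type; anything with finish() satisfies the interface.
struct my_context
{
    void finish() const {}
};

int main()
{
    migraphx::context a = my_context{}; // wraps a fresh handle
    migraphx::context b = a;            // copies the wrapper, shares the handle
    migraphx::context c = my_context{}; // equal-looking value, distinct handle

    assert(is_shared(a, b));     // same underlying handle
    assert(not is_shared(a, c)); // different handles, even if values match
    return 0;
}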
@@ -94,6 +94,12 @@ constexpr void each_args(F)
 {
 }

+template <class F, class T>
+auto unpack(F f, T& x)
+{
+    return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
+}
+
 /// Implements a fix-point combinator
 template <class R, class F>
 detail::fix_f<R, F> fix(F f)
...
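For readers unfamiliar with MIGraphX's `sequence_c`, the new `unpack` above forwards the elements of a tuple-like value as separate arguments to a callable. A self-contained sketch of the same effect using `std::index_sequence` instead of `sequence_c` (which is not shown in this diff):

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Apply f to every element of a tuple-like value x.
template <class F, class T, std::size_t... Is>
auto unpack_impl(F f, T& x, std::index_sequence<Is...>)
{
    return f(std::get<Is>(x)...);
}

template <class F, class T>
auto unpack(F f, T& x)
{
    return unpack_impl(f, x, std::make_index_sequence<std::tuple_size<T>::value>{});
}

// Usage: call a 3-argument lambda with the contents of a 3-element array,
// which is exactly how par_dfor (later in this commit) hands indices to f.
// std::array<std::size_t, 3> idx{{0, 1, 2}};
// unpack([](std::size_t i, std::size_t j, std::size_t k) { /* ... */ }, idx);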
@@ -14,6 +14,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);

 struct instruction
 {
@@ -73,6 +74,8 @@ struct instruction

     argument eval() const;

+    void finalize(context& ctx);
+
     static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);

     private:
...
@@ -26,6 +26,8 @@ struct operation
 {
     /// A unique name identifying the operation
     std::string name() const;
+    /// An optional method that can be used to finalize the operator before running
+    void finalize(context& ctx);
     /// This is used to compute the resulting shape from an operation. If an
     /// operation cannot be run with input shapes, then it should throw an
     /// exception.
@@ -55,6 +57,8 @@ struct operation
 /// Returns true if operation does not require a context to run compute
 bool is_context_free(const operation& x);
+/// Returns true if the operation has a finalize method
+bool has_finalize(const operation& x);

 #else
@@ -189,6 +193,44 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
     return output_alias_op(rank<1>{}, x, shapes);
 }

+template <class T>
+auto finalize_op(
+    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
+{
+    x.finalize(auto_any_cast(ctx), output_shape, input);
+}
+
+template <class T>
+void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
+{
+}
+
+template <class T>
+void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+{
+    finalize_op(rank<1>{}, x, ctx, output_shape, input);
+}
+
+template <class T>
+auto has_finalize_op(
+    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
+
+template <class T>
+auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
+    -> std::false_type;
+
+template <class T>
+auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
+                                                           std::declval<T&>(),
+                                                           std::declval<context&>(),
+                                                           std::declval<const shape&>(),
+                                                           std::declval<std::vector<shape>>()))
+{
+    return {};
+}
+
 /*
  * Type-erased interface for:
  *
@@ -196,7 +238,9 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
  * {
  *     std::string name() const;
  *     bool is_context_free() const;
+ *     bool has_finalize() const;
  *     int output_alias(const std::vector<shape>& input) const;
+ *     void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
  *     shape compute_shape(const std::vector<shape>& input) const;
  *     argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
  *     argument compute(const shape& output, const std::vector<argument>& input) const;
@@ -275,12 +319,24 @@ struct operation
         return (*this).private_detail_te_get_handle().is_context_free();
     }

+    bool has_finalize() const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().has_finalize();
+    }
+
     int output_alias(const std::vector<shape>& input) const
     {
         assert((*this).private_detail_te_handle_mem_var);
         return (*this).private_detail_te_get_handle().output_alias(input);
     }

+    void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().finalize(ctx, output, input);
+    }
+
     shape compute_shape(const std::vector<shape>& input) const
     {
         assert((*this).private_detail_te_handle_mem_var);
@@ -311,6 +367,12 @@ struct operation
         return x.private_detail_te_get_handle().operator==(y);
     }

+    friend bool is_shared(const operation& private_detail_x, const operation& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
+    }
+
     private:
     struct private_detail_te_handle_base_type
     {
@@ -320,7 +382,10 @@ struct operation
         virtual std::string name() const = 0;
         virtual bool is_context_free() const = 0;
+        virtual bool has_finalize() const = 0;
         virtual int output_alias(const std::vector<shape>& input) const = 0;
+        virtual void
+        finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
         virtual shape compute_shape(const std::vector<shape>& input) const = 0;
         virtual argument
         compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
@@ -365,12 +430,20 @@ struct operation
             return is_context_free_op(private_detail_te_value);
         }

+        bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }
+
         int output_alias(const std::vector<shape>& input) const override
         {
             return output_alias_op(private_detail_te_value, input);
         }

+        void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
+        {
+            finalize_op(private_detail_te_value, ctx, output, input);
+        }
+
         shape compute_shape(const std::vector<shape>& input) const override
         {
@@ -478,6 +551,14 @@ bool is_context_free(const T& x)
     return is_context_free_op(x);
 }

+inline bool has_finalize(const operation& op) { return op.has_finalize(); }
+
+template <class T>
+bool has_finalize(const T& x)
+{
+    return has_finalize_op(x);
+}
+
 #endif

 } // namespace MIGRAPHX_INLINE_NS
...
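The `finalize_op`/`has_finalize_op` overloads above use MIGraphX's `rank`-based tag dispatch to detect a member function at compile time: the `rank<1>` overload participates only when the `decltype` expression compiles, and the `rank<0>` overload is the silent fallback. A stripped-down, self-contained illustration of the idiom (simplified to a parameterless `finalize()`; `rank` is redefined locally so the snippet compiles on its own):

// rank<1> derives from rank<0>, so overload resolution prefers the rank<1>
// overload and falls back only when its decltype expression is SFINAE'd out.
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

// Chosen only when x.finalize() is a valid expression.
template <class T>
auto call_finalize(rank<1>, T& x) -> decltype(x.finalize(), void())
{
    x.finalize();
}

// Fallback: types without finalize() are silently skipped.
template <class T>
void call_finalize(rank<0>, T&)
{
}

struct with_finalize
{
    bool finalized = false;
    void finalize() { finalized = true; }
};
struct without_finalize
{
};

int main()
{
    with_finalize a;
    without_finalize b;
    call_finalize(rank<1>{}, a); // calls a.finalize()
    call_finalize(rank<1>{}, b); // no-op, resolves to the rank<0> overload
    return a.finalized ? 0 : 1;
}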
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP

#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class... Ts>
auto par_dfor(Ts... xs)
{
    return [=](auto f) {
        using array_type = std::array<std::size_t, sizeof...(Ts)>;
        array_type lens  = {{static_cast<std::size_t>(xs)...}};
        auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
        const std::size_t min_grain = 8;
        if(n > 2 * min_grain)
        {
            array_type strides;
            strides.fill(1);
            std::partial_sum(lens.rbegin(),
                             lens.rend() - 1,
                             strides.rbegin() + 1,
                             std::multiplies<std::size_t>());
            auto size =
                std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
            par_for(size, min_grain, [&](std::size_t i) {
                array_type indices;
                std::transform(strides.begin(),
                               strides.end(),
                               lens.begin(),
                               indices.begin(),
                               [&](size_t stride, size_t len) { return (i / stride) % len; });
                migraphx::unpack(f, indices);
            });
        }
        else
        {
            dfor(xs...)(f);
        }
    };
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
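A usage sketch for this new header: `par_dfor(lens...)` returns a callable that invokes `f` once per point of the index space, in parallel when the total trip count exceeds `2 * min_grain` (16 here) and serially via `dfor` otherwise. Iteration order is unspecified on the parallel path, so the body must only touch disjoint data, as in this hypothetical fill:

#include <migraphx/par_dfor.hpp>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t rows = 64, cols = 64;
    std::vector<int> m(rows * cols);
    // Each (i, j) writes a distinct element, so the parallel path is safe.
    migraphx::par_dfor(rows, cols)(
        [&](std::size_t i, std::size_t j) { m[i * cols + j] = static_cast<int>(i + j); });
    return 0;
}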
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP

#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct joinable_thread : std::thread
{
    template <class... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
    {
    }

    joinable_thread& operator=(joinable_thread&& other) = default;
    joinable_thread(joinable_thread&& other)            = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
    if(threadsize <= 1)
    {
        for(std::size_t i = 0; i < n; i++)
            f(i);
    }
    else
    {
        std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
        const
#endif
            std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());

        std::size_t work = 0;
        std::generate(threads.begin(), threads.end(), [=, &work] {
            auto result = joinable_thread([=] {
                std::size_t start = work;
                std::size_t last  = std::min(n, work + grainsize);
                for(std::size_t i = start; i < last; i++)
                {
                    f(i);
                }
            });
            work += grainsize;
            return result;
        });
        assert(work >= n);
    }
}

template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
    const auto threadsize =
        std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
    par_for_impl(n, threadsize, f);
}

template <class F>
void par_for(std::size_t n, F f)
{
    const int min_grain = 8;
    par_for(n, min_grain, f);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
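A corresponding sketch for `par_for`: it splits `[0, n)` into contiguous chunks of roughly `n / hardware_concurrency()` iterations, one per thread, and runs serially when `n / min_grain` leaves at most one thread's worth of work. As with `par_dfor`, distinct iterations must write to distinct locations:

#include <migraphx/par_for.hpp>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<double> xs(1 << 12, 2.0);
    // Square every element; each index i is owned by exactly one thread.
    migraphx::par_for(xs.size(), [&](std::size_t i) { xs[i] *= xs[i]; });
    return 0;
}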
@@ -105,7 +105,13 @@ struct pass
     void apply(program& p) const
     {
         assert((*this).private_detail_te_handle_mem_var);
-        return (*this).private_detail_te_get_handle().apply(p);
+        (*this).private_detail_te_get_handle().apply(p);
+    }
+
+    friend bool is_shared(const pass& private_detail_x, const pass& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
     }

     private:
@@ -149,7 +155,7 @@ struct pass
         std::string name() const override { return private_detail_te_value.name(); }

-        void apply(program& p) const override { return private_detail_te_value.apply(p); }
+        void apply(program& p) const override { private_detail_te_value.apply(p); }

         PrivateDetailTypeErasedT private_detail_te_value;
     };
...
@@ -91,10 +91,14 @@ struct program

     shape get_shape() const;

+    context& get_context() const;
+
     instruction_ref validate() const;
     void compile(const target& t, tracer trace = tracer{});
+    void finalize();

     void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;

     void debug_print() const;
...
@@ -127,6 +127,12 @@ struct target
         return (*this).private_detail_te_get_handle().get_context();
     }

+    friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
+    }
+
     private:
     struct private_detail_te_handle_base_type
     {
...
@@ -162,14 +162,6 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
     old->remove_output(*this);
 }

-std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
-{
-    std::vector<shape> shapes(args.size());
-    std::transform(
-        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
-    return shapes;
-}
-
 argument instruction::eval() const
 {
     if(op.name() == "@literal")
@@ -191,9 +183,15 @@ argument instruction::eval() const
     return {};
 }

+void instruction::finalize(context& ctx)
+{
+    if(has_finalize(this->op))
+        this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
+}
+
 instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
 {
-    auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
+    auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
     if(i < 0)
         return ins;
     if(shallow)
@@ -201,9 +199,17 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
     return get_output_alias(ins->inputs().at(i));
 }

+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
+{
+    std::vector<shape> shapes(args.size());
+    std::transform(
+        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
+    return shapes;
+}
+
 shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
 {
-    return op.compute_shape(compute_shapes(args));
+    return op.compute_shape(to_shapes(args));
 }

 } // namespace MIGRAPHX_INLINE_NS
...
@@ -271,6 +271,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }

 shape program::get_shape() const { return impl->instructions.back().get_shape(); }

+context& program::get_context() const { return impl->ctx; }
+
 instruction_ref program::validate() const
 {
     return std::find_if(impl->instructions.begin(),
@@ -309,6 +311,15 @@ void program::compile(const target& t, tracer trace)
         auto index = std::distance(impl->instructions.begin(), invalid);
         MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
     }
+    this->finalize();
+}
+
+void program::finalize()
+{
+    for(auto ins : iterator_for(*this))
+    {
+        ins->finalize(this->impl->ctx);
+    }
 }

 template <class F>
...
@@ -5,6 +5,7 @@
 #include <migraphx/operators.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/par_dfor.hpp>
 #include <migraphx/cpu/gemm.hpp>
 #include <unordered_map>
 #include <utility>
@@ -72,7 +73,7 @@ struct cpu_batch_norm_inference
             visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
                 [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
-                    dfor(num_batch, num_channels, image_height, image_width)(
+                    par_dfor(num_batch, num_channels, image_height, image_width)(
                         [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
                             assert((variance(c) + epsilon) > 0);
                             result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
@@ -87,7 +88,7 @@ struct cpu_batch_norm_inference
             visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
                 [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
-                    dfor(num_batch, num_channels, image_height, image_width)(
+                    par_dfor(num_batch, num_channels, image_height, image_width)(
                         [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
                             assert((variance(c, h, w) + epsilon) > 0);
                             result(n, c, h, w) = gamma(c, h, w) *
@@ -122,7 +123,7 @@ struct cpu_convolution
             auto wei_h = wei[2];
             auto wei_w = wei[3];

-            dfor(output_shape.lens()[0],
+            par_dfor(output_shape.lens()[0],
                  output_shape.lens()[1],
                  output_shape.lens()[2],
                  output_shape.lens()[3])(
@@ -245,7 +246,7 @@ struct cpu_pooling
             auto in_h = input.get_shape().lens()[2];
             auto in_w = input.get_shape().lens()[3];

-            dfor(output_shape.lens()[0],
+            par_dfor(output_shape.lens()[0],
                  output_shape.lens()[1],
                  output_shape.lens()[2],
                  output_shape.lens()[3])(
...
@@ -41,11 +41,11 @@ argument miopen_convolution::compute(context& ctx,

 shape miopen_convolution::compile(context& ctx,
                                   const shape& output_shape,
-                                  std::vector<instruction_ref> inputs)
+                                  std::vector<shape> inputs)
 {
     shape workspace_shape{};

-    auto x_desc = make_tensor(inputs[0]->get_shape());
-    auto w_desc = make_tensor(inputs[1]->get_shape());
+    auto x_desc = make_tensor(inputs[0]);
+    auto w_desc = make_tensor(inputs[1]);
     auto y_desc = make_tensor(output_shape);

     std::size_t workspace_size = 0;
@@ -57,8 +57,8 @@ shape miopen_convolution::compile(context& ctx,
                                  &workspace_size);
     workspace_shape = shape{shape::int8_type, {workspace_size}};

-    auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
-    auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
+    auto x = to_gpu(generate_argument(inputs[0]));
+    auto w = to_gpu(generate_argument(inputs[1]));
     auto y         = allocate_gpu(output_shape);
     auto workspace = allocate_gpu(workspace_shape);
@@ -80,10 +80,21 @@ shape miopen_convolution::compile(context& ctx,
                                  false);
     if(status != miopenStatusSuccess)
         MIGRAPHX_THROW("Find convolution failed");
+    handle = ctx.get_stream().get_miopen();
     algo = perf.fwd_algo;
     return shape{shape::int8_type, {perf.memory}};
 }

+void miopen_convolution::finalize(context& ctx,
+                                  const shape& output_shape,
+                                  std::vector<shape> inputs)
+{
+    if(handle == ctx.get_stream().get_miopen())
+        return;
+    // TODO: Check that workspace hasn't changed
+    compile(ctx, output_shape, std::move(inputs));
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
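The `finalize` added above encodes a cache-invalidation choice: `compile()` records which MIOpen handle the algorithm search ran against, and `finalize` repeats the expensive search only when the program is finalized against a different handle. A schematic of that pattern, detached from MIOpen (types and names here are illustrative, not the real API):

#include <utility>

using device_handle = void*; // stands in for miopenHandle_t

struct tuned_op
{
    device_handle handle = nullptr;

    // Expensive one-time setup, e.g. an algorithm search.
    void compile(device_handle h)
    {
        /* ... run auto-tuning against h ... */
        handle = h; // remember which handle the result is valid for
    }

    // Cheap when nothing changed; re-tunes only for a new handle.
    void finalize(device_handle h)
    {
        if(handle == h)
            return;
        compile(h);
    }
};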
@@ -274,11 +274,8 @@ struct miopen_conv_bias
         return f.execute(ctx, fargs, args[0], args[4]);
     }

-    shape compile(context& ctx)
-    {
-        f.compile(ctx);
-        return f.get_workspace(ctx);
-    }
+    void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
+    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }

     int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
 };
@@ -318,12 +315,8 @@ struct miopen_conv_bias_relu
         miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
         return f.execute(ctx, fargs, args[0], args[4]);
     }

-    shape compile(context& ctx)
-    {
-        f.compile(ctx);
-        return f.get_workspace(ctx);
-    }
+    void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
+    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }

     int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
 };
@@ -350,7 +343,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
     Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
     // TODO: Insert ws allocation
-    auto ws = cb.compile(ctx);
+    auto ws = cb.get_workspace(ctx);
     p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
...
@@ -27,6 +27,7 @@ struct miopen_convolution
     op::convolution op;
     shared<convolution_descriptor> cd;
     miopenConvFwdAlgorithm_t algo{};
+    miopenHandle_t handle = nullptr;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -39,7 +40,8 @@ struct miopen_convolution
     shape compute_shape(const std::vector<shape>& inputs) const;
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
-    shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs);
+    shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
+    void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
     int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
 };
...
@@ -132,7 +132,7 @@ struct miopen_apply
         auto&& op = any_cast<op::convolution>(ins->get_operator());
         auto conv = miopen_convolution{op, make_conv(op)};
-        auto ws   = conv.compile(ctx, ins->get_shape(), ins->inputs());
+        auto ws   = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));

         auto workspace = insert_allocation(ins, ws, "workspace");
         auto output    = insert_allocation(ins, ins->get_shape());
...
@@ -8,9 +8,57 @@
 struct id_target
 {
+    struct context
+    {
+        void finish() const {}
+    };
+    migraphx::context ctx = context{};
     std::string name() const { return "id"; }
     std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
-    migraphx::context get_context() const { return {}; }
+    migraphx::context get_context() const { return ctx; }
+};
+
+struct id_ctx_op
+{
+    std::string name() const { return "id_ctx_op"; }
+    migraphx::argument
+    compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
+    {
+        if(args.empty())
+            return {};
+        return args.front();
+    }
+    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
+    {
+        if(inputs.empty())
+            return {};
+        return inputs.front();
+    }
+    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
+};
+
+struct id_ctx_final_op
+{
+    std::string name() const { return "id_ctx_final_op"; }
+    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
+    {
+        if(args.empty())
+            return {};
+        return args.front();
+    }
+    void finalize(id_target::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
+    {
+    }
+    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
+    {
+        if(inputs.empty())
+            return {};
+        return inputs.front();
+    }
+    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
 };

 struct reverse_pass
@@ -224,4 +272,52 @@ TEST_CASE(double_reverse_target_test)
     EXPECT(result != migraphx::literal{4});
 }

+// Check that the program doesn't modify the context directly; only the
+// operators modify the context
+TEST_CASE(eval_context1)
+{
+    migraphx::program p;
+    id_target t{};
+    EXPECT(is_shared(t.ctx, t.get_context()));
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    p.add_instruction(sum_op{}, one, two);
+    p.compile(t);
+    EXPECT(is_shared(t.ctx, p.get_context()));
+    p.eval({});
+    EXPECT(is_shared(t.ctx, p.get_context()));
+}
+
+TEST_CASE(eval_context2)
+{
+    migraphx::program p;
+    id_target t{};
+    EXPECT(is_shared(t.ctx, t.get_context()));
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    p.add_instruction(id_ctx_op{}, one, two);
+    p.compile(t);
+    EXPECT(is_shared(t.ctx, p.get_context()));
+    p.eval({});
+    // id_ctx_op will modify the context
+    EXPECT(not is_shared(t.ctx, p.get_context()));
+}
+
+TEST_CASE(eval_context3)
+{
+    migraphx::program p;
+    id_target t{};
+    EXPECT(is_shared(t.ctx, t.get_context()));
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    p.add_instruction(id_ctx_final_op{}, one, two);
+    p.compile(t);
+    // The finalizer will modify the context
+    EXPECT(not is_shared(t.ctx, p.get_context()));
+    auto ctx = p.get_context();
+    p.eval({});
+    EXPECT(is_shared(ctx, p.get_context()));
+    EXPECT(not is_shared(t.ctx, p.get_context()));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -224,7 +224,13 @@ inline void run(int argc, const char* argv[])
         std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
                                                                  get_test_cases().end());
         for(auto&& name : cases)
-            run_test_case(name, m[name]);
+        {
+            auto f = m.find(name);
+            if(f == m.end())
+                std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl;
+            else
+                run_test_case(name, f->second);
+        }
     }
 }
...