Commit 3d08f40f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into scatter-op

parents 50cfbcda 6ba279cc
...@@ -37,6 +37,7 @@ add_library(migraphx ...@@ -37,6 +37,7 @@ add_library(migraphx
msgpack.cpp msgpack.cpp
operation.cpp operation.cpp
permutation.cpp permutation.cpp
preallocate_param.cpp
process.cpp process.cpp
program.cpp program.cpp
module.cpp module.cpp
......
...@@ -124,6 +124,14 @@ argument::data_t argument::data_t::from_args(const std::vector<argument>& args) ...@@ -124,6 +124,14 @@ argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
return result; return result;
} }
argument argument::copy() const
{
argument result{this->get_shape()};
auto* src = this->data();
std::copy(src, src + this->get_shape().bytes(), result.data());
return result;
}
argument argument::share() const { return {m_shape, m_data.share()}; } argument argument::share() const { return {m_shape, m_data.share()}; }
std::vector<argument> argument::get_sub_objects() const std::vector<argument> argument::get_sub_objects() const
......
...@@ -26,6 +26,8 @@ struct allocation_model ...@@ -26,6 +26,8 @@ struct allocation_model
std::string copy() const; std::string copy() const;
/// Create an allocation operator for the given shape /// Create an allocation operator for the given shape
operation allocate(const shape& s) const; operation allocate(const shape& s) const;
/// Create a preallocated operator for the given shape
operation preallocate(const shape& s, const std::string& id) const;
}; };
#else #else
...@@ -38,6 +40,7 @@ struct allocation_model ...@@ -38,6 +40,7 @@ struct allocation_model
* std::string name() const; * std::string name() const;
* std::string copy() const; * std::string copy() const;
* operation allocate(const shape& s) const; * operation allocate(const shape& s) const;
* operation preallocate(const shape& s,std::string id) const;
* }; * };
* *
*/ */
...@@ -123,6 +126,12 @@ struct allocation_model ...@@ -123,6 +126,12 @@ struct allocation_model
return (*this).private_detail_te_get_handle().allocate(s); return (*this).private_detail_te_get_handle().allocate(s);
} }
operation preallocate(const shape& s, std::string id) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().preallocate(s, std::move(id));
}
friend bool is_shared(const allocation_model& private_detail_x, friend bool is_shared(const allocation_model& private_detail_x,
const allocation_model& private_detail_y) const allocation_model& private_detail_y)
{ {
...@@ -137,9 +146,10 @@ struct allocation_model ...@@ -137,9 +146,10 @@ struct allocation_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::string copy() const = 0; virtual std::string copy() const = 0;
virtual operation allocate(const shape& s) const = 0; virtual operation allocate(const shape& s) const = 0;
virtual operation preallocate(const shape& s, std::string id) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -180,6 +190,12 @@ struct allocation_model ...@@ -180,6 +190,12 @@ struct allocation_model
return private_detail_te_value.allocate(s); return private_detail_te_value.allocate(s);
} }
operation preallocate(const shape& s, std::string id) const override
{
return private_detail_te_value.preallocate(s, std::move(id));
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -60,6 +60,8 @@ struct argument : raw_data<argument> ...@@ -60,6 +60,8 @@ struct argument : raw_data<argument>
argument reshape(const shape& s) const; argument reshape(const shape& s) const;
argument copy() const;
/// Make copy of the argument that is always sharing the data /// Make copy of the argument that is always sharing the data
argument share() const; argument share() const;
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class lifetime
{
local,
global,
borrow
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -35,7 +36,7 @@ struct as_shape ...@@ -35,7 +36,7 @@ struct as_shape
{ {
return args.front().reshape(output_shape); return args.front().reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -66,7 +67,7 @@ struct broadcast ...@@ -66,7 +67,7 @@ struct broadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -50,7 +51,7 @@ struct flatten ...@@ -50,7 +51,7 @@ struct flatten
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -36,7 +37,7 @@ struct load ...@@ -36,7 +37,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return argument::load(s, args[0].data() + offset); return argument::load(s, args[0].data() + offset);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op) friend std::ostream& operator<<(std::ostream& os, const load& op)
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -68,7 +69,7 @@ struct multibroadcast ...@@ -68,7 +69,7 @@ struct multibroadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived> ...@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived>
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape&, std::vector<argument> args) const
{ {
argument result = args[0]; argument result = args[0].copy();
auto s = result.get_shape(); auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens(); auto lens = s.lens();
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -71,7 +72,7 @@ struct reshape ...@@ -71,7 +72,7 @@ struct reshape
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -39,7 +40,7 @@ struct scalar ...@@ -39,7 +40,7 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -77,7 +78,7 @@ struct squeeze ...@@ -77,7 +78,7 @@ struct squeeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -71,7 +72,7 @@ struct step ...@@ -71,7 +72,7 @@ struct step
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -63,7 +64,7 @@ struct transpose ...@@ -63,7 +64,7 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -70,7 +71,7 @@ struct unsqueeze ...@@ -70,7 +71,7 @@ struct unsqueeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <migraphx/module_ref.hpp> #include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -435,9 +436,9 @@ void from_value_op(T& x, const value& v) ...@@ -435,9 +436,9 @@ void from_value_op(T& x, const value& v)
} }
template <class T> template <class T>
bool is_borrowed_op(const T&) lifetime get_lifetime_op(const T&)
{ {
return false; return lifetime::local;
} }
} // namespace detail } // namespace detail
...@@ -451,7 +452,7 @@ bool is_borrowed_op(const T&) ...@@ -451,7 +452,7 @@ bool is_borrowed_op(const T&)
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const; * bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* bool is_borrowed() const; * lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; * value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
...@@ -559,10 +560,10 @@ struct operation ...@@ -559,10 +560,10 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
bool is_borrowed() const lifetime get_lifetime() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed(); return (*this).private_detail_te_get_handle().get_lifetime();
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
...@@ -678,7 +679,7 @@ struct operation ...@@ -678,7 +679,7 @@ struct operation
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0; virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0; virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0; compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
...@@ -750,16 +751,16 @@ struct operation ...@@ -750,16 +751,16 @@ struct operation
} }
template <class T> template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self) static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed()) -> decltype(private_detail_te_self.get_lifetime())
{ {
return private_detail_te_self.is_borrowed(); return private_detail_te_self.get_lifetime();
} }
template <class T> template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self) static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{ {
return detail::is_borrowed_op(private_detail_te_self); return detail::get_lifetime_op(private_detail_te_self);
} }
template <class T> template <class T>
...@@ -1044,10 +1045,10 @@ struct operation ...@@ -1044,10 +1045,10 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value); return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
} }
bool is_borrowed() const override lifetime get_lifetime() const override
{ {
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value); return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP #define MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu { struct module;
struct preallocate_param struct preallocate_param
{ {
std::string param{}; std::string param;
context* ctx = nullptr; allocation_model model;
std::string name() const { return "preallocate_param"; } std::string name() const { return "preallocate_param"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#endif
...@@ -440,17 +440,19 @@ bool is_borrowed(instruction_ref ins) ...@@ -440,17 +440,19 @@ bool is_borrowed(instruction_ref ins)
auto alias = instruction::get_output_alias(ins, true); auto alias = instruction::get_output_alias(ins, true);
if(alias == ins) if(alias == ins)
return false; return false;
if(alias->get_operator().is_borrowed()) lifetime l = alias->get_operator().get_lifetime();
if(l == lifetime::borrow)
return true; return true;
return is_borrowed(alias); return is_borrowed(alias);
} }
bool is_param_alias(instruction_ref ins) bool is_global(instruction_ref ins)
{ {
return instruction::get_output_alias(ins)->name() == "@param"; const auto& op = instruction::get_output_alias(ins)->get_operator();
return op.name() == "@param" or op.get_lifetime() == lifetime::global;
} }
bool is_dangling(instruction_ref ins) { return not is_param_alias(ins) and is_borrowed(ins); } bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); }
instruction_ref module::find_dangling_reference() const instruction_ref module::find_dangling_reference() const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment