Unverified Commit 427fc25c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Preallocate parameters on the CPU and unify preallocations (#840)



* Add preallocate method

* Add preallocate_param pass

* Preallocate buffers on the cpu

* Formatting

* Preallocate on the gpu

* Add missing cpp file

* Formatting

* Add lifetime function

* Formatting

* Always allocate

* Fix tidy warning

* Add const

* Add missing lifetime annotations
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent f60c3815
......@@ -37,6 +37,7 @@ add_library(migraphx
msgpack.cpp
operation.cpp
permutation.cpp
preallocate_param.cpp
process.cpp
program.cpp
module.cpp
......
......@@ -26,6 +26,8 @@ struct allocation_model
std::string copy() const;
/// Create an allocation operator for the given shape
operation allocate(const shape& s) const;
/// Create a preallocated operator for the given shape
operation preallocate(const shape& s, const std::string& id) const;
};
#else
......@@ -38,6 +40,7 @@ struct allocation_model
* std::string name() const;
* std::string copy() const;
* operation allocate(const shape& s) const;
* operation preallocate(const shape& s,std::string id) const;
* };
*
*/
......@@ -123,6 +126,12 @@ struct allocation_model
return (*this).private_detail_te_get_handle().allocate(s);
}
operation preallocate(const shape& s, std::string id) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().preallocate(s, std::move(id));
}
friend bool is_shared(const allocation_model& private_detail_x,
const allocation_model& private_detail_y)
{
......@@ -137,9 +146,10 @@ struct allocation_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::string copy() const = 0;
virtual operation allocate(const shape& s) const = 0;
virtual std::string name() const = 0;
virtual std::string copy() const = 0;
virtual operation allocate(const shape& s) const = 0;
virtual operation preallocate(const shape& s, std::string id) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -180,6 +190,12 @@ struct allocation_model
return private_detail_te_value.allocate(s);
}
operation preallocate(const shape& s, std::string id) const override
{
return private_detail_te_value.preallocate(s, std::move(id));
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class lifetime
{
local,
global,
borrow
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -35,7 +36,7 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -66,7 +67,7 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -10,6 +10,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -50,7 +51,7 @@ struct flatten
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -6,6 +6,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -36,7 +37,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds");
return argument::load(s, args[0].data() + offset);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
......
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -68,7 +69,7 @@ struct multibroadcast
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -71,7 +72,7 @@ struct reshape
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -39,7 +40,7 @@ struct scalar
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -9,6 +9,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -77,7 +78,7 @@ struct squeeze
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -7,6 +7,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -71,7 +72,7 @@ struct step
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -63,7 +64,7 @@ struct transpose
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -70,7 +71,7 @@ struct unsqueeze
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -15,6 +15,7 @@
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -435,9 +436,9 @@ void from_value_op(T& x, const value& v)
}
template <class T>
bool is_borrowed_op(const T&)
lifetime get_lifetime_op(const T&)
{
return false;
return lifetime::local;
}
} // namespace detail
......@@ -451,7 +452,7 @@ bool is_borrowed_op(const T&)
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* bool is_borrowed() const;
* lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
......@@ -559,10 +560,10 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize();
}
bool is_borrowed() const
lifetime get_lifetime() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
return (*this).private_detail_te_get_handle().get_lifetime();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const
......@@ -678,7 +679,7 @@ struct operation
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
......@@ -750,16 +751,16 @@ struct operation
}
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_lifetime())
{
return private_detail_te_self.is_borrowed();
return private_detail_te_self.get_lifetime();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
return detail::get_lifetime_op(private_detail_te_self);
}
template <class T>
......@@ -1044,10 +1045,10 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
bool is_borrowed() const override
lifetime get_lifetime() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct module;
struct preallocate_param
{
std::string param{};
context* ctx = nullptr;
std::string param;
allocation_model model;
std::string name() const { return "preallocate_param"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif // MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
......@@ -440,17 +440,19 @@ bool is_borrowed(instruction_ref ins)
auto alias = instruction::get_output_alias(ins, true);
if(alias == ins)
return false;
if(alias->get_operator().is_borrowed())
lifetime l = alias->get_operator().get_lifetime();
if(l == lifetime::borrow)
return true;
return is_borrowed(alias);
}
bool is_param_alias(instruction_ref ins)
bool is_global(instruction_ref ins)
{
return instruction::get_output_alias(ins)->name() == "@param";
const auto& op = instruction::get_output_alias(ins)->get_operator();
return op.name() == "@param" or op.get_lifetime() == lifetime::global;
}
bool is_dangling(instruction_ref ins) { return not is_param_alias(ins) and is_borrowed(ins); }
bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); }
instruction_ref module::find_dangling_reference() const
{
......
#include <migraphx/gpu/preallocate_param.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void preallocate_param::apply(module& p) const
void preallocate_param::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() != "@param")
continue;
if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
continue;
std::string id = p.name() + ":" + param;
auto r = p.insert_instruction(ins, hip_allocate_memory{ins->get_shape(), id});
p.replace_instruction(ins, r);
std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r);
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -19,6 +19,7 @@ add_library(migraphx_cpu
logsoftmax.cpp
lowering.cpp
lrn.cpp
preallocate.cpp
pooling.cpp
reduction.cpp
reorder.cpp
......
......@@ -11,6 +11,11 @@ operation cpu_allocation_model::allocate(const shape& s) const
return make_op(name(), {{"shape", to_value(s)}});
}
operation cpu_allocation_model::preallocate(const shape& s, const std::string& id) const
{
return make_op("cpu::preallocate", {{"shape", to_value(s)}, {"id", id}});
}
std::string cpu_allocation_model::copy() const { return "cpu::copy"; }
} // namespace cpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment