"...composable_kernel_rocm.git" did not exist on "c3a4652a6857df8ab26f5c9fad5f68c99b36f7c4"
Unverified Commit 406afeb8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use dnnl for cpu backend (#688)



* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Add onednn

* Formatting

* Formatting

* Add dnnl header

* Formatting

* Rewrite rnn first

* Formatting

* Call reference implementation

* Formatting

* Make literal data shared

* Formatting

* Add convolution

* Formatting

* Compensate for dilation

* Formatting

* Use name/make_op instead

* Formatting

* Rename gemm header

* Formatting

* Add dnnl convolution/gemm operators

* Formatting

* Add eliminate_contiguous

* Add faster pointwise operators

* Formatting

* Formatting

* Formatting

* Add dnnl op class

* Formatting

* Add add op

* Formatting

* Add concat operator

* Formatting

* Add more ops

* Create descriptor during finalization

* Formatting

* Don't rewrite pooling

* Enable memory coloring

* Formatting

* Add output aliases

* Formatting

* Fix errors

* Formatting

* Convert literals

* Add missing file

* Remove batch_norm

* Formatting

* Use strides

* Formatting

* Add some debug checks

* Formatting

* Fix bug in adjusting shape for gemm

* Formatting

* Fix fallback dot operator

* Zero initialize buffers

* Add support for group convolutions

* Formatting

* Make adjust allocation target independent

* Formatting

* Enable adjust_allocation for gpu/cpu

* Formatting

* Add copy to allocation model

* Formatting

* Add copy operator

* Formatting

* Better handling of output parameters in adjust_allocation

* Formatting

* Build with dnnl

* Make dnnl required

* Fix compile error

* Tidy fixes

* Formatting

* Tidy fixes

* Formatting

* Fix more tidy issues

* Formatting

* Add mul op

* Add mul op

* Set C compiler to clang as well

* Compensate for normalized compute shape

* Formatting

* Fix cppcheck errors

* Formatting

* Add onednn library to hcc

* Guard clang pragmas

* Disable cpu mode for gcc for now

* Leave it enabled for gcc 7

* Fix cppcheck suppression

* Fix compile error on gcc 5

* Remove unused code
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8698cd2c
@@ -138,7 +138,7 @@ rocmtest format: rocmnode('rocmtest') { cmake_build ->
     stage('GCC 5 Release') {
         cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=release")
     }
-}, gcc7: rocmnode('rocmtest') { cmake_build ->
+}, gcc7: rocmhipclangnode('rocmtest') { cmake_build ->
     stage('GCC 7 Debug') {
         def linker_flags = '-fuse-ld=gold'
         def cmake_linker_flags = "-DCMAKE_EXE_LINKER_FLAGS='${linker_flags}' -DCMAKE_SHARED_LINKER_FLAGS='${linker_flags}'"
@@ -153,7 +153,7 @@ rocmtest format: rocmnode('rocmtest') { cmake_build ->
         def linker_flags = '-fuse-ld=gold'
         def cmake_linker_flags = "-DCMAKE_EXE_LINKER_FLAGS='${linker_flags}' -DCMAKE_SHARED_LINKER_FLAGS='${linker_flags}'"
         def debug_flags = "-g -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer"
-        cmake_build("g++-7", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_CPU=On -DMIGRAPHX_ENABLE_PYTHON=Off ${cmake_linker_flags} -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
+        cmake_build("g++-7", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_CPU=Off -DMIGRAPHX_ENABLE_PYTHON=Off ${cmake_linker_flags} -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
     }
     stage('Codecov') {
......
@@ -65,6 +65,7 @@ RUN cget -p $PREFIX ignore \
     ROCmSoftwarePlatform/MIOpen \
     ROCmSoftwarePlatform/MIOpenGEMM \
     ROCmSoftwarePlatform/rocBLAS
-RUN cget -p $PREFIX init --cxx /opt/rocm/llvm/bin/clang++
+RUN cget -p $PREFIX init --cxx /opt/rocm/llvm/bin/clang++ --cc /opt/rocm/llvm/bin/clang
 RUN cget -p $PREFIX install -f dev-requirements.txt
+RUN cget -p $PREFIX install oneapi-src/oneDNN@v1.7
@@ -5,6 +5,7 @@ include(RegisterOp)
 include(CheckCXXLinkerFlag)

 add_library(migraphx
+    adjust_allocation.cpp
     analyze_streams.cpp
     auto_contiguous.cpp
     eliminate_common_subexpression.cpp
......
-#include <migraphx/gpu/adjust_allocation.hpp>
+#include <migraphx/adjust_allocation.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/program.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {

 void adjust_allocation::apply(module& p) const
 {
@@ -15,23 +16,32 @@ void adjust_allocation::apply(module& p) const
         if(ins->inputs().empty())
             continue;
-        if(ins->name() == "load")
+        // Skip target-independent operators
+        if(ins->get_operator().is_context_free())
             continue;
         auto alias_ins = instruction::get_output_alias(ins, true);
-        if(alias_ins->name() == "hip::allocate")
+        if(alias_ins->name() != model.name() and alias_ins->name() != "@param")
+            continue;
+        // shape allocated is different from actual shape
+        // of the instruction, reallocate and replace the previous one
+        if(alias_ins->get_shape() == ins->get_shape())
+            continue;
+        auto alloc_ins = p.insert_instruction(ins, model.allocate(ins->get_shape()));
+        p.replace_instruction(alias_ins, alloc_ins);
+        // If the memory is an output parameter then copy the memory to the parameter
+        if(alias_ins->name() == "@param")
         {
-            // shape allocated is different from actual shape
-            // of the instruction, reallocate and replace the previous one
-            if(alias_ins->get_shape() != ins->get_shape())
+            auto copy = p.insert_instruction(std::next(ins), make_op(model.copy()), ins, alias_ins);
+            auto tail = range(std::next(copy), p.end());
+            for(auto i : iterator_for(tail))
             {
-                auto alloc_ins = p.insert_instruction(ins, hip_allocate{ins->get_shape()});
-                p.replace_instruction(alias_ins, alloc_ins);
+                if(contains(i->inputs(), ins))
+                    instruction::replace_argument(i, ins, copy);
             }
         }
     }
 }

-} // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -7,12 +7,8 @@
 #include <migraphx/float_equal.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/op/dot.hpp>
-#include <migraphx/op/multibroadcast.hpp>
-#include <migraphx/op/mul.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/add.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 namespace {
@@ -28,11 +24,20 @@ struct find_dot_add
            not contains({shape::float_type, shape::half_type, shape::double_type},
                         ins->get_shape().type()))
             return;
-        auto dot_ins = p.insert_instruction(ins,
-                                            make_op("dot", {{"alpha", dot.alpha}, {"beta", 0}}),
-                                            ins->inputs()[0],
-                                            ins->inputs()[1]);
-        auto c_ins = ins->inputs()[2];
+        auto a_ins = ins->inputs()[0];
+        auto b_ins = ins->inputs()[1];
+        if(not float_equal(dot.alpha, 1))
+        {
+            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha_broadcast = p.insert_instruction(
+                ins,
+                make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
+                alpha);
+            a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
+        }
+        auto dot_ins = p.insert_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        auto c_ins = ins->inputs()[2];
         if(not float_equal(dot.beta, 1))
         {
             auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
@@ -44,9 +49,32 @@
         }
     }
 };

+struct find_dot_alpha
+{
+    auto matcher() const { return match::name("dot")(match::nargs(2)); }
+
+    void apply(program& p, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto dot = any_cast<op::dot>(ins->get_operator());
+        auto a_ins = ins->inputs()[0];
+        auto b_ins = ins->inputs()[1];
+        if(not float_equal(dot.alpha, 1))
+        {
+            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha_broadcast = p.insert_instruction(
+                ins,
+                make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
+                alpha);
+            a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
+        }
+        p.replace_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+    }
+};
+
 } // namespace

-void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
+void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}, find_dot_alpha{}); }

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
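Both matchers above rely on the same algebraic fact: scaling one operand of a matrix product scales the product, so the dot operator's alpha attribute can be folded into an explicit broadcast multiply on the first input. A quick standalone check of that identity (plain C++, not MIGraphX code):

// Verifies alpha * (A x B) == (alpha * A) x B on a small example, which is
// the identity the decompose rewrite above depends on.
#include <cassert>
#include <cmath>

int main()
{
    const double alpha  = 2.5;
    const double a[2][2] = {{1, 2}, {3, 4}};
    const double b[2][2] = {{5, 6}, {7, 8}};
    for(int i = 0; i < 2; i++)
    {
        for(int j = 0; j < 2; j++)
        {
            double ab        = 0; // alpha applied after the product
            double scaled_ab = 0; // alpha folded into the first operand
            for(int k = 0; k < 2; k++)
            {
                ab += a[i][k] * b[k][j];
                scaled_ab += (alpha * a[i][k]) * b[k][j];
            }
            assert(std::fabs(alpha * ab - scaled_ab) < 1e-12);
        }
    }
    return 0;
}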
File mode changed from 100644 to 100755
 #ifndef MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP
 #define MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP

-#include <migraphx/program.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/gpu/context.hpp>
+#include <migraphx/allocation_model.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
+
+struct program;
+using module = program;

 struct adjust_allocation
 {
-    std::string name() const { return "gpu::adjust_allocation"; }
+    allocation_model model;
+    std::string name() const { return "adjust_allocation"; }
     void apply(module& p) const;
 };

-} // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP
#define MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP

#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

#ifdef DOXYGEN

/// An interface for target-dependent allocation
struct allocation_model
{
    /// A name of the target-dependent allocate operator
    std::string name() const;
    /// A name of the target-dependent copy operator
    std::string copy() const;
    /// Create an allocation operator for the given shape
    operation allocate(const shape& s) const;
};

#else

/*
 * Type-erased interface for:
 *
 * struct allocation_model
 * {
 *     std::string name() const;
 *     std::string copy() const;
 *     operation allocate(const shape& s) const;
 * };
 *
 */

struct allocation_model
{
    // Constructors
    allocation_model() = default;

    template <typename PrivateDetailTypeErasedT>
    allocation_model(PrivateDetailTypeErasedT value)
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
    {
    }

    // Assignment
    template <typename PrivateDetailTypeErasedT>
    allocation_model& operator=(PrivateDetailTypeErasedT value)
    {
        using std::swap;
        auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
        if(derived and private_detail_te_handle_mem_var.unique())
        {
            *derived = std::forward<PrivateDetailTypeErasedT>(value);
        }
        else
        {
            allocation_model rhs(value);
            swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
        }
        return *this;
    }

    // Cast
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

    std::string name() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().name();
    }

    std::string copy() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().copy();
    }

    operation allocate(const shape& s) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().allocate(s);
    }

    friend bool is_shared(const allocation_model& private_detail_x,
                          const allocation_model& private_detail_y)
    {
        return private_detail_x.private_detail_te_handle_mem_var ==
               private_detail_y.private_detail_te_handle_mem_var;
    }

    private:
    struct private_detail_te_handle_base_type
    {
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const = 0;

        virtual std::string name() const = 0;
        virtual std::string copy() const = 0;
        virtual operation allocate(const shape& s) const = 0;
    };

    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
        {
        }

        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
        }

        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
        {
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
        }

        const std::type_info& type() const override { return typeid(private_detail_te_value); }

        std::string name() const override { return private_detail_te_value.name(); }
        std::string copy() const override { return private_detail_te_value.copy(); }
        operation allocate(const shape& s) const override
        {
            return private_detail_te_value.allocate(s);
        }

        PrivateDetailTypeErasedT private_detail_te_value;
    };

    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
        {
        }
    };

    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        return *private_detail_te_handle_mem_var;
    }

    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        if(!private_detail_te_handle_mem_var.unique())
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};

template <typename ValueType>
inline const ValueType* any_cast(const allocation_model* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType* any_cast(allocation_model* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType& any_cast(allocation_model& x)
{
    auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

template <typename ValueType>
inline const ValueType& any_cast(const allocation_model& x)
{
    const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

#endif

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
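For orientation, here is a minimal sketch of a type that satisfies this interface. The operator names and the "shape" attribute are assumptions for illustration (modeled on the GPU target's hip::allocate), not something this commit defines:

// Hypothetical model for a CPU-like target; "cpu::allocate"/"cpu::copy" are
// assumed operator names, and passing the shape as a "shape" value attribute
// is likewise an assumption about the allocate operator's interface.
#include <migraphx/allocation_model.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>

struct example_allocation_model
{
    std::string name() const { return "cpu::allocate"; }
    std::string copy() const { return "cpu::copy"; }
    migraphx::operation allocate(const migraphx::shape& s) const
    {
        return migraphx::make_op(name(), {{"shape", migraphx::to_value(s)}});
    }
};

// Any such type can then be stored in the type-erased wrapper and handed to
// the target-independent pass:
//     migraphx::adjust_allocation pass{example_allocation_model{}};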
@@ -4,6 +4,7 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/raw_data.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/make_shared_array.hpp>
 #include <functional>
 #include <utility>
@@ -23,8 +24,8 @@ struct argument : raw_data<argument>
     argument(const shape& s) : m_shape(s)
     {
-        auto buffer = std::make_shared<std::vector<char>>(s.bytes());
-        data = [=]() mutable { return buffer->data(); };
+        auto buffer = make_shared_array<char>(s.bytes());
+        data = [=]() mutable { return buffer.get(); };
     }

     template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
......
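With this change the argument constructor hands out one shared, zero-initialized byte buffer, so copies of an argument alias the same storage. A small check of that behavior (a sketch assuming only the public argument/shape API):

// Copies of an argument share the underlying buffer, and the buffer starts
// zeroed because make_shared_array value-initializes its array.
#include <algorithm>
#include <cassert>
#include <migraphx/argument.hpp>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {4}};
    migraphx::argument a{s};
    migraphx::argument b = a; // shares the same storage as a
    assert(a.data() == b.data());
    char* p = a.data();
    assert(std::all_of(p, p + s.bytes(), [](char c) { return c == 0; }));
    return 0;
}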
@@ -172,6 +172,13 @@ struct check_shapes
         return *this;
     }

+    const check_shapes& batch_not_transposed() const
+    {
+        if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
+            MIGRAPHX_THROW(prefix() + "Batch size is transposed");
+        return *this;
+    }
+
     template <class F>
     bool same(F f) const
     {
@@ -207,6 +214,28 @@ struct check_shapes
     check_shapes slice(long start) const { return {get(start), end, name}; }
     check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }

+    private:
+    static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides)
+    {
+        if(strides.size() <= 2)
+            return true;
+        auto dim_0       = strides.size() - 2;
+        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
+        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
+        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
+        {
+            return false;
+        }
+
+        if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
+               return (i < j or i < matrix_size or j < matrix_size);
+           }) != batch.end())
+        {
+            return false;
+        }
+        return true;
+    }
 };

 } // namespace MIGRAPHX_INLINE_NS
......
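To see which stride patterns the new check accepts, here is a standalone replica of batch_not_transposed_strides with two worked cases. The replica mirrors the logic in the diff above; it is not the library function itself:

// Replica of the private helper, for illustration: batch strides must all be
// at least as large as the matrix extent and must be non-increasing.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

static bool batch_not_transposed(const std::vector<std::size_t>& strides)
{
    if(strides.size() <= 2)
        return true;
    auto dim_0       = strides.size() - 2;
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
    // Transposed: every batch stride fits inside the matrix extent
    if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return i < matrix_size; }))
        return false;
    // Transposed: batch strides increase or interleave with the matrix
    if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
           return i < j or i < matrix_size or j < matrix_size;
       }) != batch.end())
        return false;
    return true;
}

int main()
{
    assert(batch_not_transposed({60, 20, 5, 1}));  // standard packed strides
    assert(!batch_not_transposed({5, 1, 60, 20})); // batch transposed into the matrix
    return 0;
}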
@@ -15,8 +15,6 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct program;
-
 #ifdef DOXYGEN

 /// An interface for target-dependent optimization for the concat instruction
......
@@ -68,8 +68,8 @@ struct literal : raw_data<literal>
     /// Convert the data to an argument
     argument get_argument() const
     {
-        std::vector<char> b(buffer.get(), buffer.get() + m_shape.bytes());
-        return {m_shape, [b]() mutable { return b.data(); }};
+        auto b = make_shared_array<char>(buffer.get(), buffer.get() + m_shape.bytes());
+        return {m_shape, [b]() { return b.get(); }};
     }

     private:
......
@@ -10,7 +10,15 @@ inline namespace MIGRAPHX_INLINE_NS {
 template <typename T>
 std::shared_ptr<T> make_shared_array(size_t size)
 {
-    return std::shared_ptr<T>(new T[size], std::default_delete<T[]>()); // NOLINT
+    return std::shared_ptr<T>(new T[size](), std::default_delete<T[]>()); // NOLINT
+}
+
+template <class T, class Iterator>
+std::shared_ptr<T> make_shared_array(Iterator start, Iterator last)
+{
+    auto result = make_shared_array<T>(std::distance(start, last));
+    std::copy(start, last, result.get());
+    return result;
 }

 } // namespace MIGRAPHX_INLINE_NS
......
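Two details of this header are easy to miss: the added () in new T[size]() value-initializes the array (the "Zero initialize buffers" item in the commit list), and the new overload copies an iterator range into a freshly allocated array. A minimal check:

// Demonstrates both make_shared_array overloads touched above.
#include <cassert>
#include <vector>
#include <migraphx/make_shared_array.hpp>

int main()
{
    // new char[16]() value-initializes, so every byte starts at zero
    auto zeroed = migraphx::make_shared_array<char>(16);
    for(int i = 0; i < 16; i++)
        assert(zeroed.get()[i] == 0);

    // The iterator overload copies the range into a new shared array
    std::vector<char> src = {1, 2, 3};
    auto copy = migraphx::make_shared_array<char>(src.begin(), src.end());
    assert(copy.get()[2] == 3);
    return 0;
}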
@@ -44,6 +44,11 @@ struct contiguous
         });
         return result;
     }

+    auto apply() const
+    {
+        return [](auto x) { return x; };
+    }
 };

 } // namespace op
......
@@ -822,6 +822,25 @@ inline const ValueType& any_cast(const operation& x)

 inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }

+inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
+{
+    return op.compute_shape(inputs);
+}
+
+template <class T>
+inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
+    -> decltype(op.compute_shape(inputs))
+{
+    return op.compute_shape(inputs);
+}
+
+template <class T>
+inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
+    -> decltype(op.normalize_compute_shape(inputs))
+{
+    return detail::normalize_compute_shape_op(op, inputs);
+}
+
 inline bool is_context_free(const operation& op) { return op.is_context_free(); }

 template <class T>
......
@@ -5,6 +5,9 @@
 #include <vector>
 #include <initializer_list>
 #include <migraphx/rank.hpp>
+#include <migraphx/type_name.hpp>
+#include <migraphx/errors.hpp>
+#include <migraphx/requires.hpp>
 #include <migraphx/config.hpp>

 namespace migraphx {
@@ -34,6 +37,33 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
     return std::find(c.begin(), c.end(), x);
 }

+template <class C, class T>
+auto generic_find_at_impl(rank<1>, C&& c, const T& x) -> decltype(c.find(x))
+{
+    return c.find(x);
+}
+
+template <class C, class T>
+auto generic_find_at_impl(rank<0>, C&& c, const T& x)
+{
+    auto n = std::distance(c.begin(), c.end());
+    if(x >= n)
+        return c.end();
+    return std::next(c.begin(), x);
+}
+
+template <class C, class T, class = typename C::mapped_type>
+decltype(auto) generic_at_impl(rank<1>, const C&, T&& it)
+{
+    return it->second;
+}
+
+template <class C, class T>
+decltype(auto) generic_at_impl(rank<0>, const C&, T&& it)
+{
+    return *it;
+}
+
 struct empty
 {
 };
@@ -46,6 +76,20 @@ auto generic_find(C&& c, const T& x)
     return detail::generic_find_impl(rank<2>{}, c, x);
 }

+template <class C, class T>
+decltype(auto) at(C&& c, const T& x, const std::string& msg = "")
+{
+    auto it = detail::generic_find_at_impl(rank<2>{}, c, x);
+    if(it == c.end())
+    {
+        if(msg.empty())
+            MIGRAPHX_THROW("At operator out of range for " + get_type_name(c));
+        else
+            MIGRAPHX_THROW(msg);
+    }
+    return detail::generic_at_impl(rank<2>{}, c, it);
+}
+
 template <class C, class T>
 bool contains(const C& c, const T& x)
 {
......
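The new migraphx::at helper unifies keyed and positional lookup: the rank-based dispatch above picks c.find(x) when the container supports it and otherwise advances an iterator, throwing through MIGRAPHX_THROW on a miss. A usage sketch:

// at() works on both associative and sequence containers.
#include <cassert>
#include <map>
#include <string>
#include <vector>
#include <migraphx/ranges.hpp>

int main()
{
    std::map<std::string, int> m = {{"dot", 1}};
    std::vector<int> v           = {10, 20, 30};
    assert(migraphx::at(m, std::string{"dot"}) == 1); // keyed: uses m.find
    assert(migraphx::at(v, 1) == 20);                 // positional: std::next
    // migraphx::at(m, std::string{"gemm"}) would throw
    // "At operator out of range for ..." via MIGRAPHX_THROW
    return 0;
}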
@@ -12,6 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {

 void register_op(const operation& op);
 operation load_op(const std::string& name);
+bool has_op(const std::string& name);
 std::vector<std::string> get_operators();

 template <class T>
......
@@ -102,6 +102,7 @@ struct shape
     std::size_t index(std::size_t i) const;

     std::vector<std::size_t> multi(std::size_t i) const;
+    void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;

     /// Returns true if the shape is packed with no padding
     bool packed() const;
......
 #include <migraphx/register_op.hpp>
+#include <migraphx/ranges.hpp>
 #include <unordered_map>

 namespace migraphx {
@@ -12,10 +13,7 @@ std::unordered_map<std::string, operation>& op_map()
 void register_op(const operation& op) { op_map()[op.name()] = op; }

 operation load_op(const std::string& name)
 {
-    auto it = op_map().find(name);
-    if(it == op_map().end())
-        MIGRAPHX_THROW("Operator not found: " + name);
-    return it->second;
+    return at(op_map(), name, "Operator not found: " + name);
 }

 std::vector<std::string> get_operators()
......
@@ -146,16 +146,24 @@ std::vector<std::size_t> shape::multi(std::size_t i) const
     assert(this->standard());
     std::vector<std::size_t> indices(lens().size());
+    multi_copy(i, indices.data(), indices.data() + lens().size());
+    return indices;
+}
+
+void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
+{
+    assert(this->standard());
+    (void)end;
+    assert(lens().size() <= (end - start));
     std::transform(strides().begin(),
                    strides().end(),
                    lens().begin(),
-                   indices.begin(),
+                   start,
                    [&](std::size_t stride, std::size_t len) {
                        assert(len > 0 and stride > 0);
                        return (i / stride) % len;
                    });
-    return indices;
 }

 bool shape::packed() const { return this->elements() == this->element_space(); }
......
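shape::multi_copy writes the multi-dimensional index straight into a caller-provided buffer instead of allocating a vector, presumably for the faster pointwise operators added in this PR. The index math itself is unchanged; a worked example:

// For lens {2, 3, 4} the standard (row-major) strides are {12, 4, 1}, so
// flat index 17 maps to (17/12)%2 = 1, (17/4)%3 = 1, (17/1)%4 = 1.
#include <cassert>
#include <vector>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
    assert(s.multi(17) == (std::vector<std::size_t>{1, 1, 1}));
    return 0;
}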