Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
......@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
return x.index - y.index;
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x -= y;
}
template <class F, class Iterator>
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
......
......@@ -8,6 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val);
value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size);
......
......@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_MARKER_HPP
#define MIGRAPHX_GUARD_MARKER_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// Marker is an interface to general marking functions, such as rocTX markers.
#else
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct marker
{
//
void mark_start(instruction_ref ins_ref);
//
void mark_start(const program& prog);
//
void mark_stop(instruction_ref ins);
//
void mark_stop(const program& prog);
};
#else
struct marker
{
// Constructors
marker() = default;
template <typename PrivateDetailTypeErasedT>
marker(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
marker& operator=(PrivateDetailTypeErasedT value)
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
marker rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void mark_start(instruction_ref ins_ref)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(ins_ref);
}
void mark_start(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(prog);
}
void mark_stop(instruction_ref ins)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(ins);
}
void mark_stop(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(prog);
}
friend bool is_shared(const marker& private_detail_x, const marker& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void mark_start(instruction_ref ins_ref) = 0;
virtual void mark_start(const program& prog) = 0;
virtual void mark_stop(instruction_ref ins) = 0;
virtual void mark_stop(const program& prog) = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void mark_start(instruction_ref ins_ref) override
{
private_detail_te_value.mark_start(ins_ref);
}
void mark_start(const program& prog) override { private_detail_te_value.mark_start(prog); }
void mark_stop(instruction_ref ins) override { private_detail_te_value.mark_stop(ins); }
void mark_stop(const program& prog) override { private_detail_te_value.mark_stop(prog); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(marker& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const marker& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -101,17 +101,17 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
[=, name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
/// Convert a matcher to a bindable matcher
......@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
......@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
......@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
......@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
......@@ -263,6 +309,20 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
return result;
}
/// Find first instance of a matching instruction in a module
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
match::matcher_result result;
for(auto ins : iterator_for(modl))
{
result = match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module
......@@ -519,10 +579,22 @@ auto skip_output(Ms... ms)
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
}
inline auto name_contains(const std::string& name)
......@@ -533,7 +605,7 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
......@@ -682,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
return false;
auto l = ins->get_literal();
......
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -46,6 +46,9 @@ struct module
std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args)
{
......@@ -93,6 +96,11 @@ struct module
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
......@@ -107,6 +115,8 @@ struct module
instruction_ref add_return(std::vector<instruction_ref> args);
instruction_ref replace_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......
......@@ -18,6 +18,8 @@ struct onnx_options
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
/// Max iter num for the loop operator
int64_t max_loop_iterations = 10;
};
/// Create a program from an onnx file
......@@ -29,6 +31,8 @@ program parse_onnx_buffer(const std::string& buffer, const onnx_options& options
/// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options);
std::vector<std::string> get_onnx_operators();
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -35,7 +35,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -35,7 +35,7 @@ struct argmin
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -36,7 +36,6 @@ struct as_shape
{
return args.front().reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -30,7 +30,7 @@ struct broadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
}
std::string name() const { return "broadcast"; }
......@@ -67,7 +67,6 @@ struct broadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
......@@ -29,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
// the context argument is added to prevent the op from be eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
......@@ -42,6 +45,8 @@ struct capture
return args.front();
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -7,7 +7,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -21,25 +21,26 @@ struct clip
{
std::string name() const { return "clip"; }
value attributes() const
{
return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_type();
check_shapes{inputs, *this}.has(3).same_type().same_dims();
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto input, auto min_val, auto max_val) {
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
});
return result;
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <utility>
......@@ -15,6 +17,15 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max,
lpnorm
};
// indicate rnn computation direction
enum class rnn_direction
{
......@@ -23,6 +34,7 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
......
......@@ -32,6 +32,11 @@ struct convert : unary<convert>
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
std::string point_op() const
{
return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})";
}
auto apply() const
{
auto type = target_type;
......
......@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath>
#include <utility>
......@@ -70,6 +72,78 @@ struct deconvolution
return inputs[0].with_lens(output_lens);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto kdims = this->kdims();
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n]));
}
const int group_id = w / (wei_n / group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
size_t kdims() const
{
check_attribute_size();
......
......@@ -25,14 +25,15 @@ struct dequantizelinear
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
check_shapes{inputs, *this}.same_dims();
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.elements(), 0);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
......
......@@ -18,19 +18,10 @@ namespace op {
struct dot
{
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type();
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -58,25 +49,14 @@ struct dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result;
if(args.size() == 3)
result = args[2];
else
result = argument{output_shape};
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
......
......@@ -51,7 +51,6 @@ struct flatten
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment