Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_DOT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <cmath>
#include <utility>
......@@ -18,19 +18,10 @@ namespace op {
struct dot
{
float alpha = 1.0;
float beta = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type();
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -58,15 +49,16 @@ struct dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
} // namespace op
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_ELU_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -19,7 +18,7 @@ namespace op {
struct elu
{
std::string name() const { return "elu"; }
float alpha;
float alpha = 1;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
......
#ifndef MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP
#define MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct equal : binary<equal>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
std::string point_function() const { return "=="; }
auto apply() const
{
return [](auto x, auto y) { return float_equal(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -2,13 +2,15 @@
#define MIGRAPHX_GUARD_OPERATORS_FLATTEN_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -18,7 +20,7 @@ namespace op {
struct flatten
{
uint64_t axis = 0;
int64_t axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -26,16 +28,19 @@ struct flatten
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] =
value::array{normalize_attribute::include_min, normalize_attribute::include_max};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens();
if(axis > lens.size())
{
MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
}
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y =
......@@ -44,7 +49,7 @@ struct flatten
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -2,13 +2,14 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -18,7 +19,7 @@ namespace op {
struct gather
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -26,27 +27,25 @@ struct gather
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{
MIGRAPHX_THROW("Gather: axis is out of range.");
}
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index);
lens.erase(lens.begin() + axis);
if(!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
......@@ -62,10 +61,8 @@ struct gather
{
argument result{output_shape};
// negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens();
int axis_index = (axis < 0) ? static_cast<int>(lens.size() + axis) : axis;
std::size_t axis_dim_size = lens[axis_index];
auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis];
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
......@@ -73,18 +70,18 @@ struct gather
{
auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
output[0] = data[indices.front()];
output[0] = data[in_index];
}
else
{
auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements();
auto out_lens = data.get_shape().lens();
out_lens[axis] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
auto in_index = indices[data_idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis_index] = in_index;
auto data_idx = out_idx;
auto in_index = indices[data_idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct gathernd
{
int batch_dims = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.batch_dims, "batch_dims"));
}
std::string name() const { return "gathernd"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
auto indices_shape_lens = indices_shape.lens();
auto data_shape = data.get_shape();
auto data_shape_lens = data_shape.lens();
auto k = indices_shape.lens().back();
const auto num_slice_dims = k;
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
std::size_t data_batch_stride =
std::accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
auto num_slices_per_batch = num_slices / num_batches;
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
{
auto running_product = slice_size;
for(std::size_t i = 0; i < num_slice_dims; ++i)
{
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
}
}
std::vector<std::size_t> input_slice_offsets(num_slices);
par_for(num_slices, [&](const auto i) {
std::size_t batch_idx = i / num_slices_per_batch;
auto slice_indices = indices.begin() + (i * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
{
int64_t index = *(slice_indices + dim_idx);
const std::size_t input_dim_idx = batch_dims + dim_idx;
const auto input_dim = data_shape_lens[input_dim_idx];
if(index < -static_cast<int64_t>(input_dim) or
index >= static_cast<int64_t>(input_dim))
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
" is out of bounds for dim of len " +
std::to_string(input_dim));
if(index < 0)
index += input_dim;
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}
input_slice_offsets[i] =
(batch_idx * data_batch_stride) + relative_slice_offset;
});
par_for(num_slices * slice_size, [&](const auto i) {
auto slice_offset = input_slice_offsets[i / slice_size];
output[i] = data[slice_offset + i % slice_size];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct get_tuple_elem
{
std::size_t index = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.index, "index"));
}
std::string name() const { return "get_tuple_elem"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).tuple_type();
const auto& sub_shapes = inputs.at(0).sub_shapes();
if(index >= sub_shapes.size())
{
MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) + " is out of range " +
std::to_string(sub_shapes.size()));
}
return sub_shapes.at(index);
}
argument compute(const shape&, std::vector<argument> args) const
{
assert(args.size() == 1);
auto vec_args = args.at(0).get_sub_objects();
assert(index < vec_args.size());
return vec_args.at(index);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_GREATER_HPP
#define MIGRAPHX_GUARD_OPERATORS_GREATER_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct greater : binary<greater>
{
std::string point_function() const { return ">"; }
auto apply() const
{
return [](auto x, auto y) { return x > y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,9 +3,9 @@
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_IDENTITY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -20,10 +19,8 @@ struct identity
{
std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
argument compute(shape, std::vector<argument> args) const { return args[0]; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <cmath>
#include <utility>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct if_op
{
std::string name() const { return "if"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.standard();
if(mods.size() != 2)
{
MIGRAPHX_THROW("IF: operator should have two submodules.");
}
auto out_shapes0 = mods[0]->get_output_shapes();
auto out_shapes1 = mods[1]->get_output_shapes();
if(not std::equal(
out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end()))
{
MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
}
return {out_shapes0};
}
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
auto cond = args.front().at<bool>();
module_ref mod = cond ? mods[0] : mods[1];
std::unordered_map<std::string, argument> params;
std::set<std::string> pnames;
for(const auto& smod : mods)
{
auto names = smod->get_parameter_names();
pnames.insert(names.begin(), names.end());
}
assert(pnames.size() < args.size());
std::transform(pnames.begin(),
pnames.end(),
args.begin() + 1,
std::inserter(params, params.end()),
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
auto results = run(mod, params);
return argument{results};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,12 +2,8 @@
#define MIGRAPHX_GUARD_OPERATORS_IM2COL_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -18,9 +14,9 @@ namespace op {
struct im2col
{
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
std::vector<std::size_t> padding{0, 0};
std::vector<std::size_t> stride{1, 1};
std::vector<std::size_t> dilation{1, 1};
padding_mode_t padding_mode = default_;
......@@ -35,7 +31,9 @@ struct im2col
std::string name() const { return "im2col"; }
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
auto input = inputs[0];
auto weights = inputs[1];
......@@ -46,17 +44,24 @@ struct im2col
check_shapes{inputs, *this}.has(2);
if(batch_size != 1)
MIGRAPHX_THROW("im2col only support batch_size 1");
auto padding_h = 2 * padding[0];
auto padding_w = 2 * padding[1];
if(padding.size() == 2 * stride.size())
{
padding_h = padding[0] + padding[2];
padding_w = padding[1] + padding[3];
}
auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
stride[0] +
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
stride[1] +
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
1));
auto channels_col = kernel_height * kernel_width * input_channels;
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}};
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct isnan : unary<isnan>
{
auto apply() const
{
return [](auto x) { return std::isnan(x); };
}
std::string name() const { return "isnan"; }
shape compute_shape(std::vector<shape> inputs) const
{
return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_LEAKY_RELU_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -18,7 +17,7 @@ namespace op {
struct leaky_relu
{
float alpha;
float alpha = 0.01;
template <class Self, class F>
static auto reflect(Self& self, F f)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LESS_HPP
#define MIGRAPHX_GUARD_OPERATORS_LESS_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct less : binary<less>
{
std::string point_function() const { return "<"; }
auto apply() const
{
return [](auto x, auto y) { return x < y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,13 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_LOAD_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -30,15 +28,16 @@ struct load
std::string name() const { return "load"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
return s;
}
argument compute(const shape&, const std::vector<argument>& args) const
{
if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset};
return argument{s, args[0].data() + offset};
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGICAL_AND_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGICAL_AND_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct logical_and : binary<logical_and>
{
std::string point_function() const { return "&&"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) and static_cast<bool>(y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGICAL_OR_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGICAL_OR_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct logical_or : binary<logical_or>
{
std::string point_function() const { return "||"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) or static_cast<bool>(y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment