Commit 5d057776 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge changes from the develop branch

parents d6b4ae77 9b19b73f
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!"); "\": invalid value!");
} }
std::string mode = "avg"; migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode")) if(contains(info.attributes, "mode"))
{ {
mode = info.attributes.at("mode").s(); // read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
} }
int64_t output_height = 1; int64_t output_height = 1;
...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{ {
spatial_scale = info.attributes.at("spatial_scale").f(); spatial_scale = info.attributes.at("spatial_scale").f();
} }
return info.add_instruction(make_op("roialign", return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode}, {{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode}, {"mode", rmode},
{"output_height", output_height}, {"output_height", output_height},
{"output_width", output_width}, {"output_width", output_width},
{"sampling_ratio", sampling_ratio}, {"sampling_ratio", sampling_ratio},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the onnx "Size" operator: the element count of the input is
// known from its static shape at parse time, so the result is emitted as a
// scalar int64 literal rather than a runtime instruction.
struct parse_size : op_parser<parse_size>
{
    std::vector<op_desc> operators() const { return {{"Size"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Total number of elements in the (single) input tensor.
        const auto nelements = args[0]->get_shape().elements();
        migraphx::shape scalar_shape{migraphx::shape::int64_type};
        return info.add_literal(migraphx::literal{scalar_shape, {nelements}});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Stream a pooling_mode in human-readable form so users can save and edit
// model files. The strings are the same values used for onnx parsing, but
// this enum is not onnx-specific: strings must be converted when parsing tf.
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
    // A switch avoids the per-translation-unit static vector of strings
    // (heap allocation + init guard) and the unchecked out-of-range indexing
    // the table-lookup version performed for an invalid enum value.
    switch(v)
    {
    case pooling_mode::average: os << "average"; break;
    case pooling_mode::max: os << "max"; break;
    }
    return os;
}
// Stream an rnn_direction in human-readable form (same spellings as the
// onnx attribute values).
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
    // A switch avoids the static vector of strings and the unchecked
    // out-of-range indexing of the table-lookup version.
    switch(v)
    {
    case rnn_direction::forward: os << "forward"; break;
    case rnn_direction::reverse: os << "reverse"; break;
    case rnn_direction::bidirectional: os << "bidirectional"; break;
    }
    return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -556,8 +556,14 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -556,8 +556,14 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 0) if(trace_level > 0)
{ {
auto max_ins_len = max_ins_length(); auto max_ins_len = max_ins_length();
std::unordered_map<instruction_ref, std::string> ins_names; std::unordered_map<instruction_ref, std::string> ins_out;
this->print(ins_names, [&](auto, auto) {}); // get instruction names
this->print([&](auto x, auto ins_names) {
std::stringstream ss;
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
if(trace_level == 3) if(trace_level == 3)
{ {
std::string prefix = "Run instruction: "; std::string prefix = "Run instruction: ";
...@@ -571,9 +577,7 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -571,9 +577,7 @@ std::vector<argument> program::eval(parameter_map params) const
with_check_context([&](auto& ins, auto f, auto&& check_context) { with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish(); ctx.finish();
std::stringstream ss; std::stringstream ss;
ss << "Run instruction: "; ss << "Run instruction: " << ins_out.at(ins);
this->debug_print(ss, ins, ins_names);
timer t{}; timer t{};
auto result = check_context(f); auto result = check_context(f);
double t1 = t.record<milliseconds>(); double t1 = t.record<milliseconds>();
...@@ -583,7 +587,7 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -583,7 +587,7 @@ std::vector<argument> program::eval(parameter_map params) const
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
std::cout << "Time: " << t1 << "ms, " << t2 std::cout << "Time: " << t1 << "ms, " << t2
<< "ms, execution time:\t"; << "ms" << std::endl;
if(trace_level == 2 and ins->name().front() != '@' and if(trace_level == 2 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty()) ins->name() != "load" and not result.empty())
{ {
...@@ -1007,6 +1011,14 @@ void program::print( ...@@ -1007,6 +1011,14 @@ void program::print(
} }
} }
void program::print(
const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>& print_func) const
{
std::unordered_map<instruction_ref, std::string> names;
this->print(names, print_func);
}
void program::print_graph(std::ostream& os, bool brief) const void program::print_graph(std::ostream& os, bool brief) const
{ {
const auto* mm = this->get_main_module(); const auto* mm = this->get_main_module();
......
...@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
.def(py::init<>()) .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", std::string{"float"}));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens) .def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides) .def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements) .def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes) .def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size) .def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed) .def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed) .def("transposed", &migraphx::shape::transposed)
......
...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const ...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
pooling = pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
......
...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
return hs_padded; return hs_padded;
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -460,10 +460,10 @@ struct cpu_apply ...@@ -460,10 +460,10 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>()) not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value())); return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>(); op::pooling_mode mode = v["mode"].to<op::pooling_mode>();
if(mode == "max") if(mode == op::pooling_mode::max)
return replace(ins, make_op("cpu::pooling_max", v)); return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average") else if(mode == op::pooling_mode::average)
return replace(ins, make_op("cpu::pooling_average", v)); return replace(ins, make_op("cpu::pooling_average", v));
return ins; return ins;
} }
......
...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg; auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -240,8 +240,9 @@ std::string enum_params(std::size_t count, std::string param) ...@@ -240,8 +240,9 @@ std::string enum_params(std::size_t count, std::string param)
std::size_t compute_global(std::size_t n, std::size_t local) std::size_t compute_global(std::size_t n, std::size_t local)
{ {
std::size_t groups = (n + local - 1) / local; std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local; // max possible number of blocks is set to 1B (1,073,741,824)
std::size_t nglobal = std::min<std::size_t>(1073741824, groups) * local;
return nglobal; return nglobal;
} }
......
...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const ...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
// pooling_mode // pooling_mode
assert(val.contains("mode")); assert(val.contains("mode"));
auto mode = val.at("mode").to<std::string>(); auto mode = val.at("mode").to<migraphx::op::pooling_mode>();
bool is_avg_pooling = (mode == "avg"); bool is_avg_pooling = (mode == migraphx::op::pooling_mode::average);
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling)); options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode // coord_trans_mode
......
...@@ -59,8 +59,8 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal) ...@@ -59,8 +59,8 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
assert(s.elements() > 0); assert(s.elements() > 0);
index_int n = s.elements(); index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal; index_int groups = (n + nlocal - 1) / nlocal;
// change the max group num to 1 Million // max possible number of blocks is set to 1B (1,073,741,824)
index_int nglobal = std::min<index_int>((1 << 20), groups) * nlocal; index_int nglobal = std::min<index_int>(1073741824, groups) * nlocal;
assert(groups > 0); assert(groups > 0);
assert(nglobal > 0); assert(nglobal > 0);
......
...@@ -53,6 +53,12 @@ __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, O ...@@ -53,6 +53,12 @@ __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, O
output); output);
} }
// Adapt a callable so that index j is remapped to (n - j - 1): iterating the
// adapter over 0..n-1 visits the underlying range back-to-front. Used to run
// a forward block_scan as a reverse scan.
template <class F>
constexpr auto reverse_scan(index_int n, F f)
{
    // Forward the trailing arguments with their original value category
    // (static_cast<decltype(xs)&&> is equivalent to std::forward here and
    // avoids pulling <utility> into device code); the original passed them
    // as lvalues, silently copying rvalue arguments.
    return [=](auto i, auto&&... xs) {
        return f(n - i - 1, static_cast<decltype(xs)&&>(xs)...);
    };
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/device/prefix_scan_sum.hpp> #include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp> #include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp> #include <migraphx/gpu/device/reduce_ops.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
namespace migraphx { namespace migraphx {
...@@ -8,30 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -8,30 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis) void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse)
{ {
const index_int block_size = 256; const index_int max_block_size = 256;
const index_int n = arg.get_shape().lens()[axis]; const index_int n = arg.get_shape().lens()[axis];
auto rlens = result.get_shape().lens(); auto rlens = result.get_shape().lens();
rlens[axis] = 1; rlens[axis] = 1;
hip_visit_all(result, arg, result.get_shape().with_lens(rlens))( hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
[=](auto output, auto input, auto rshape) { [=](auto output, auto input, auto rshape) {
gs_launch(stream, rshape.elements() * block_size, block_size)( const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
[=](auto i, auto idx) __device__ { if(reverse and exclusive)
const auto ridx = rshape.multi(i / block_size); {
auto compute_idx = [&](auto j) { gs_launch(stream, rshape.elements() * block_size, block_size)(
auto k = ridx; [=](auto i, auto idx) __device__ {
k[axis] = j; const auto ridx = rshape.multi(i / block_size);
return k; auto compute_idx = [&](auto j) {
}; auto k = ridx;
block_scan<block_size>( k[axis] = j;
idx, return k;
sum{}, };
0, block_scan<max_block_size>(
n, idx,
[&](auto j) { return input[compute_idx(j)]; }, sum{},
[&](auto j, auto x) { output[compute_idx(j)] = x; }); 0,
}); n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) {
if(j == n - 1)
output[compute_idx(j)] = 0;
if(j > 0)
output[compute_idx(j - 1)] = x;
}));
});
}
else if(reverse)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
});
}
else if(exclusive)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) {
auto k = j + 1;
if(j == 0)
output[compute_idx(0)] = 0;
if(k < n)
output[compute_idx(k)] = x;
});
});
}
else
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) { output[compute_idx(j)] = x; });
});
}
}); });
} }
......
...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis); void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <miopen/miopen.h> #include <miopen/miopen.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <sstream>
#ifdef HAS_FIND_MODE_API #ifdef HAS_FIND_MODE_API
extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc, extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc,
int findMode); int findMode);
...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op) ...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op)
inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
{ {
miopenPoolingMode_t mode; miopenPoolingMode_t mode;
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mode = miopenPoolingMax; mode = miopenPoolingMax;
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mode = miopenPoolingAverage; mode = miopenPoolingAverage;
else else
MIGRAPHX_THROW("Unknown mode for pooling: " + op.mode); {
std::stringstream ss("Unknown mode for pooling: ");
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor); auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims(); int kdims = op.kdims();
......
...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum> ...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
if(op.exclusive or op.reverse) device::prefix_scan_sum(
MIGRAPHX_THROW("Exclusive and reverse scan not supported"); ctx.get_stream().get(), args[1], args[0], op.axis, op.exclusive, op.reverse);
device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
return args[1]; return args[1];
} }
......
...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>; ...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>;
template <auto V> template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
// Lift the result of a nullary callable into an integral_constant (via the
// _c variable template above), so the computed value is carried in the type.
// NOTE(review): f() is used as a template argument, so it must be usable in
// a constant expression — in practice f is a captureless lambda; confirm.
template <class F>
constexpr auto return_c(F f)
{
    return _c<f()>;
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP #endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
// Random-access-style iterator that produces values on demand: dereferencing
// at position i yields f(i) rather than reading from storage. Adapts an
// index-to-value function into an iterator interface for generic algorithms.
// NOTE(review): relies on std::declval / remove_reference_t / add_pointer_t
// being reachable through kernels/type_traits.hpp in device builds — confirm.
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
    Iterator index; // current position; all navigation only adjusts this
    F f;            // generator mapping an index to the produced value

    using difference_type = diff_int;
    using reference       = decltype(f(std::declval<Iterator>()));
    using value_type      = remove_reference_t<reference>;
    using pointer         = add_pointer_t<value_type>;

    constexpr basic_iota_iterator& operator+=(diff_int n)
    {
        index += n;
        return *this;
    }

    constexpr basic_iota_iterator& operator-=(diff_int n)
    {
        index -= n;
        return *this;
    }

    constexpr basic_iota_iterator& operator++()
    {
        index++;
        return *this;
    }

    constexpr basic_iota_iterator& operator--()
    {
        index--;
        return *this;
    }

    // Post-increment: return the pre-increment state by value.
    constexpr basic_iota_iterator operator++(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index++;
        return it;
    }

    // Post-decrement: return the pre-decrement state by value.
    constexpr basic_iota_iterator operator--(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index--;
        return it;
    }
    // TODO: operator->

    // Dereference: invoke the generator at the current index.
    constexpr reference operator*() const { return f(index); }

    // Subscript: invoke the generator at an offset from the current index.
    template <class T>
    constexpr reference operator[](T x) const
    {
        return f(index + x);
    }
};
// Factory that deduces the iterator type from the start index and the
// generator function.
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
    basic_iota_iterator<F, T> result{x, f};
    return result;
}
// Advance an iterator by an offset.
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    return x += y;
}

// Commutative form: offset + iterator delegates to iterator + offset.
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
    return y + x;
}

// Distance between two iterators: the difference of their indices.
template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index - y.index;
}

// Move an iterator backwards by an offset.
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    return x -= y;
}
// Comparisons are defined purely on the stored index; the generator
// function f does not participate in equality or ordering.
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index == y.index;
}

template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index != y.index;
}

template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index < y.index;
}

template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index > y.index;
}

template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index >= y.index;
}

template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index <= y.index;
}
// Identity generator: makes basic_iota_iterator yield the index itself.
// NOTE(review): the name is missing a 't' ("defaul" vs "default") — kept
// as-is since renaming would break any external users of this identifier.
struct defaul_iota_iterator
{
    template <class T>
    constexpr auto operator()(T x) const
    {
        return x;
    }
};

// Plain counting iterator: *it yields the current index value.
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts> ...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
preload<typename T::type>(idx, xs...)([&](auto... ps) { preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) { idx.global_stride(out.get_shape().elements(),
auto multi_idx = out.get_shape().multi(i); [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment