Commit 330f6db8 authored by Paul

Merge branch 'develop' into finalize

parents 333649e2 d67e7961
...@@ -94,6 +94,12 @@ constexpr void each_args(F)
{
}
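// unpack invokes f with the elements of the tuple-like object x expanded as
// individual arguments, i.e. f(std::get<0>(x), std::get<1>(x), ...).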
template <class F, class T>
auto unpack(F f, T& x)
{
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
}
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
...
...@@ -6,6 +6,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
...@@ -631,6 +633,61 @@ struct as_shape
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct gather
{
std::size_t axis = 0;
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
if(axis >= lens.size())
{
MIGRAPHX_THROW("Gather, axis is out of range.");
}
auto type = inputs[0].type();
lens[axis] = inputs[1].elements();
return {type, lens};
}
template <class T>
void compute_index(const T& out_idx,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
}
in_idx[axis] = idx;
}
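// Illustrative example: with a {3, 3} input, indices {0, 2} and axis = 0,
// compute_shape returns {2, 3} and compute copies rows 0 and 2 of the input.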
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
});
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot
{
float alpha = 1.0;
...
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
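// Parallel counterpart of dfor: par_dfor(2, 3)([](std::size_t i, std::size_t j) { ... })
// visits the 2x3 index space, using par_for when the space is large enough.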
template <class... Ts>
auto par_dfor(Ts... xs)
{
return [=](auto f) {
using array_type = std::array<std::size_t, sizeof...(Ts)>;
array_type lens = {{static_cast<std::size_t>(xs)...}};
auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
const std::size_t min_grain = 8;
if(n > 2 * min_grain)
{
array_type strides;
strides.fill(1);
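// Compute row-major strides from the innermost dimension outward,
// e.g. lens = {2, 3, 4} yields strides = {12, 4, 1}.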
std::partial_sum(lens.rbegin(),
lens.rend() - 1,
strides.rbegin() + 1,
std::multiplies<std::size_t>());
auto size =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
par_for(size, min_grain, [&](std::size_t i) {
array_type indices;
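// Decode the flat index i into a multi-index using the strides and lengths.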
std::transform(strides.begin(),
strides.end(),
lens.begin(),
indices.begin(),
[&](size_t stride, size_t len) { return (i / stride) % len; });
migraphx::unpack(f, indices);
});
}
else
{
dfor(xs...)(f);
}
};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
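// A std::thread that joins in its destructor, so the threads spawned by
// par_for_impl are synchronized when the vector of threads goes out of scope.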
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
f(i);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
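// The generator captures `work` by reference so the offset advances for each
// thread, while each thread lambda copies the current value as the start of
// its [start, start + grainsize) slice.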
std::generate(threads.begin(), threads.end(), [=, &work] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
f(i);
}
});
work += grainsize;
return result;
});
assert(work >= n);
}
}
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
par_for_impl(n, threadsize, f);
}
template <class F>
void par_for(std::size_t n, F f)
{
const int min_grain = 8;
par_for(n, min_grain, f);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -80,6 +80,9 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
} }
...@@ -356,6 +359,18 @@ struct onnx_parser ...@@ -356,6 +359,18 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
...@@ -525,6 +540,99 @@ struct onnx_parser
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
}
migraphx::shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
...@@ -774,6 +882,28 @@ struct onnx_parser
});
return {shape_type, dims};
}
shape::type_t get_type(int dtype)
{
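// Map ONNX TensorProto data type enum values to migraphx shape types.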
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default:
{
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
};
program parse_onnx(const std::string& name)
...
...@@ -5,6 +5,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/cpu/gemm.hpp>
#include <unordered_map>
#include <utility>
...@@ -72,7 +73,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c) + epsilon) > 0);
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
...@@ -87,7 +88,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c, h, w) + epsilon) > 0);
result(n, c, h, w) = gamma(c, h, w) *
...@@ -122,7 +123,7 @@ struct cpu_convolution
auto wei_h = wei[2];
auto wei_w = wei[3];
par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
...@@ -245,7 +246,7 @@ struct cpu_pooling
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3];
par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
...@@ -322,6 +323,18 @@ struct cpu_gemm
}
};
struct cpu_gather
{
op::gather op;
std::string name() const { return "cpu::gather"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct identity_op
{
std::string name() const { return "cpu::identity"; }
...@@ -651,6 +664,7 @@ struct cpu_apply
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>(); apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
...@@ -28,6 +28,7 @@ add_library(migraphx_device
device/contiguous.cpp
device/mul.cpp
device/concat.cpp
device/gather.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
...@@ -56,6 +57,7 @@ add_library(migraphx_gpu
sigmoid.cpp
abs.cpp
elu.cpp
gather.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
...
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis)
{
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
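// For each output element, replace the axis coordinate with the gathered index
// and copy the value from the corresponding input location.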
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
});
});
return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/gather.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_gather::compute_shape(std::vector<shape> inputs) const
{
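// The last input is the preallocated output buffer (see output_alias);
// drop it before validating the gather inputs.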
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_gather::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_gather
{
op::gather op;
std::string name() const { return "gpu::gather"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -40,6 +40,7 @@
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/gather.hpp>
#include <utility>
#include <functional>
#include <algorithm>
...@@ -90,7 +91,7 @@ struct miopen_apply
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
add_extend_op<hip_gather, op::gather>("gather");
add_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
...
...@@ -101,6 +101,49 @@ TEST_CASE(concat_test)
}
}
TEST_CASE(gather_test)
{
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data;
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data;
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(squeeze_test)
{
{
...
...@@ -934,6 +934,22 @@ struct test_concat_relu
}
};
struct test_gather
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
void manual_identity()
{
migraphx::program p;
...
(binary file gather_test.onnx: ONNX protobuf defining graph "test_gather" with a single Gather node over inputs "data" and "indices")
...@@ -400,6 +400,45 @@ TEST_CASE(reshape_test)
EXPECT(p == prog);
}
TEST_CASE(shape_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto l0 = p.add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens());
auto prog = migraphx::parse_onnx("shape_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(shape_gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test)
{
migraphx::program p;
...@@ -455,6 +494,33 @@ TEST_CASE(constant_test)
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test)
{
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill1.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill2.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gemm_test)
{
migraphx::program p;
...
xy"Shape
test_shapeZ
x




b
y

B
\ No newline at end of file