Commit f5e0c343 authored by Shucai Xiao's avatar Shucai Xiao

Add the gather operator and ONNX parsing of Shape and ConstantFill.

parent 301b7605
......@@ -6,6 +6,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -631,6 +633,115 @@ struct as_shape
int output_alias(const std::vector<shape>&) const { return 0; }
};
// Gather using the algorithm of the onnx::gather operator
struct gather
{
std::size_t axis = 0;
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
if(axis >= lens.size())
{
MIGRAPHX_THROW("Gather, axis is out of range.");
}
auto type = inputs[0].type();
lens[axis] = inputs[1].elements();
return {type, lens};
}
template <class T>
void compute_index(const T& out_idx, const std::vector<argument>& args, T& in_idx) const
{
in_idx = out_idx;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::size_t idx = args[1].at<std::size_t>(out_idx[axis]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather, indices are out of range in input tensor");
}
in_idx[axis] = idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
std::vector<std::size_t> in_idx;
this->compute_index(idx, args, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
});
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
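// Worked example of the onnx::gather semantics implemented above: for data
// of shape {3, 3} holding 0..8, indices {0, 2} and axis = 0, the output
// shape is {2, 3} (the axis dimension is replaced by the number of indices)
// and the result is rows 0 and 2 of the input: {0, 1, 2, 6, 7, 8}.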
// Gather using the algorithm of torch.gather, which is different
// from the onnx::gather operator.
struct gather_torch
{
std::size_t axis = 0;
std::string name() const { return "gather_torch"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
if(axis >= lens.size())
{
MIGRAPHX_THROW("Gather, axis is out of range.");
}
auto type = inputs[0].type();
// output shape is the same as that of the indices
return {type, inputs[1].lens()};
}
template <class T>
void compute_index(const T& out_idx, const std::vector<argument>& args, T& in_idx) const
{
in_idx = out_idx;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
args[1].visit([&](auto idx) {
std::size_t i = idx(out_idx.begin(), out_idx.end());
if(i >= max_dim)
{
MIGRAPHX_THROW("gather_torch, indices are out of range in input tensor");
}
in_idx[axis] = i;
});
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
std::vector<std::size_t> in_idx;
this->compute_index(out_idx, args, in_idx);
std::cout << "gather torch input = " << input(in_idx.begin(), in_idx.end())
<< std::endl;
output(out_idx.begin(), out_idx.end()) = input(in_idx.begin(), in_idx.end());
std::cout << "gather torch out = " << output(out_idx.begin(), out_idx.end())
<< std::endl;
});
});
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
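// Worked example of the torch.gather semantics implemented above: for the
// same {3, 3} data holding 0..8, indices {{0, 2}} of shape {1, 2} and
// axis = 0, the output shape is {1, 2} (the same as the indices) and
// out(0, j) = data(indices(0, j), j), giving {0, 7}.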
struct dot
{
float alpha = 1.0;
......
......@@ -80,6 +80,9 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
}
......@@ -356,6 +359,18 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
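// note: the onnx Gather node is lowered to gather_torch (not gather) to
// match the indexing used by RNN models exported from PyTorch; see the
// comment in the cpu lowering pass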
op::gather_torch op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -525,6 +540,79 @@ struct onnx_parser
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
// Use a literal instruction to implement the Shape operator, since the
// output of Shape is known at parse time in migraphx
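// e.g., for an argument of shape {2, 3, 4}, parse_shape adds the int64
// literal {2, 3, 4} with shape {3}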
instruction_ref
parse_shape(const std::string&, attribute_map, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape, operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to implement the ConstantFill operator. In RNN
// models, the input shape and fill value are fixed, so there is no need to
// compute ConstantFill at run time
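// e.g., with input_as_shape = 1, value = 0, and an input literal holding
// {2, 3}, this adds a literal of shape {2, 3} filled with zeros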
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("Constantfill, MIGraphX only handle the case with 1 operand");
}
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
}
migraphx::shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(input_as_shape == 1)
{
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW(
"ConstantFill, cannot handle dynamic shape as input for ConstantFill");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
return prog.add_literal(migraphx::literal(s, {value}));
}
else if(input_as_shape == 0)
{
std::vector<std::size_t> dims = args[0]->get_shape().lens();
migraphx::shape s{type, dims};
return prog.add_literal(migraphx::literal(s, {value}));
}
else
{
MIGRAPHX_THROW("Wrong input for ConstantFill");
}
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -774,6 +862,28 @@ struct onnx_parser
});
return {shape_type, dims};
}
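// map an onnx TensorProto::DataType enum value to the corresponding
// migraphx type; unsupported values (e.g., string and bool) hit the throw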
shape::type_t get_type(int dtype)
{
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default:
{
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
};
program parse_onnx(const std::string& name)
......
......@@ -322,6 +322,30 @@ struct cpu_gemm
}
};
struct cpu_gather
{
op::gather op;
std::string name() const { return "cpu::gather"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, args);
}
};
struct cpu_gather_torch
{
op::gather_torch op;
std::string name() const { return "cpu::gather_torch"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, args);
}
};
struct identity_op
{
std::string name() const { return "cpu::identity"; }
......@@ -651,6 +675,9 @@ struct cpu_apply
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
// To support RNN models exported from PyTorch, the gather operator is
// lowered using the indexing algorithm of torch.gather
apply_map["gather"] = extend_op<cpu_gather_torch, op::gather_torch>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
......@@ -28,6 +28,7 @@ add_library(migraphx_device
device/contiguous.cpp
device/mul.cpp
device/concat.cpp
device/gather.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
......@@ -56,6 +57,7 @@ add_library(migraphx_gpu
sigmoid.cpp
abs.cpp
elu.cpp
gather.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis)
{
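// each output element i reads the input element whose multi-index equals
// the output multi-index, with the axis coordinate replaced by the
// gathered index value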
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
});
});
return args.back();
}
argument gather_torch(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis)
{
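// for gather_torch the output shape equals the indices shape, so the index
// descriptor below is built from the output shape and the indices are
// addressed by the full output multi-index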
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
hip_tensor_descriptor<ndim> desc_ind(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[desc_ind.linear(lens)];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
});
});
return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/gather.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_gather::compute_shape(std::vector<shape> inputs) const
{
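// the last input is the pre-allocated output buffer, so drop it before
// computing the shape with the reference operator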
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_gather::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
}
shape hip_gather_torch::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_gather_torch::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
return device::gather_torch(ctx.get_stream().get(), output_shape, args, op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// uses the algorithm of onnx::gather (not used for now)
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
// uses the algorithm of torch.gather
argument gather_torch(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// uses the algorithm of onnx::gather (not used for now)
struct hip_gather
{
op::gather op;
std::string name() const { return "gpu::gather"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
// uses the algorithm of torch.gather
struct hip_gather_torch
{
op::gather_torch op;
std::string name() const { return "gpu::gather_torch"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -40,6 +40,7 @@
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/gather.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -90,7 +91,7 @@ struct miopen_apply
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
add_extend_op<hip_gather_torch, op::gather_torch>("gather");
add_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
......
......@@ -101,6 +101,49 @@ TEST_CASE(concat_test)
}
}
TEST_CASE(gather_test)
{
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 7.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(squeeze_test)
{
{
......
......@@ -934,6 +934,22 @@ struct test_concat_relu
}
};
struct test_gather
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
return p;
}
};
void manual_identity()
{
migraphx::program p;
......
(binary file gather_test.onnx: an ONNX graph named "test_gather" with inputs "data" and "indices", a single Gather node with an "axis" attribute, and output "y")
......@@ -400,6 +400,19 @@ TEST_CASE(reshape_test)
EXPECT(p == prog);
}
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather_torch{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test)
{
migraphx::program p;
......