Commit 0d8a9768 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Change back to using the onnx::gather operator.

parent eb9d3a01
......@@ -633,7 +633,6 @@ struct as_shape
int output_alias(const std::vector<shape>&) const { return 0; }
};
// Gather using the algorithm of the onnx::gather operator
struct gather
{
std::size_t axis = 0;
......@@ -684,64 +683,6 @@ struct gather
int output_alias(const std::vector<shape>&) const { return 0; }
};
// Gather using the algorithm of torch.nn.gather, which is different
// from the onnx::gather operator.
// Gather with torch.nn.gather semantics: the output has the same shape as
// the indices tensor, and each output element is read from the data tensor
// at the position given by the corresponding index value along `axis`.
struct gather_torch
{
    // Dimension along which to gather.
    std::size_t axis = 0;
    std::string name() const { return "gather_torch"; }

    // Result element type comes from the data tensor (inputs[0]); result
    // lens come from the indices tensor (inputs[1]).
    // Throws when `axis` is out of range for the data tensor's rank.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        auto lens = inputs[0].lens();
        if(axis >= lens.size())
        {
            MIGRAPHX_THROW("Gather, axis is out of range.");
        }
        auto type = inputs[0].type();
        // output shape is the same as that of the indices
        return {type, inputs[1].lens()};
    }

    // Map an output multi-index to the corresponding data multi-index:
    // identical to out_idx except along `axis`, where the coordinate is
    // read from the indices tensor (args[1]) at out_idx.
    // Throws when the stored index exceeds the data extent along `axis`.
    template <class T>
    void compute_index(const T& out_idx, const std::vector<argument>& args, T& in_idx) const
    {
        in_idx = out_idx;
        // max dimension in axis
        std::size_t max_dim = args[0].get_shape().lens()[axis];
        args[1].visit([&](auto idx) {
            std::size_t i = idx(out_idx.begin(), out_idx.end());
            if(i >= max_dim)
            {
                MIGRAPHX_THROW("gather_torch, indices are out of range in input tensor");
            }
            in_idx[axis] = i;
        });
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            shape_for_each(output.get_shape(), [&](const auto& out_idx) {
                std::vector<std::size_t> in_idx;
                this->compute_index(out_idx, args, in_idx);
                // Fix: removed leftover std::cout debug prints that dumped
                // every gathered element to stdout.
                output(out_idx.begin(), out_idx.end()) = input(in_idx.begin(), in_idx.end());
            });
        });
        return result;
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot
{
float alpha = 1.0;
......
......@@ -367,7 +367,7 @@ struct onnx_parser
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather_torch op{axis};
op::gather op{axis};
return prog.add_instruction(op, std::move(args));
}
......
......@@ -334,18 +334,6 @@ struct cpu_gather
}
};
// CPU reference target wrapper for the gather_torch operator: delegates both
// shape computation and evaluation straight to the operator itself.
struct cpu_gather_torch
{
    op::gather_torch op;
    std::string name() const { return "cpu::gather_torch"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // The operator validates the argument count and the axis range.
        return op.compute_shape(inputs);
    }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        // The context is unused: the reference operator runs on the host.
        argument result = op.compute(output_shape, std::move(args));
        return result;
    }
};
struct identity_op
{
std::string name() const { return "cpu::identity"; }
......@@ -675,9 +663,7 @@ struct cpu_apply
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
// To support the RNN from PyTorch, we need to use the gather
// algorithm of torch.nn.gather
apply_map["gather"] = extend_op<cpu_gather_torch, op::gather_torch>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
......@@ -37,33 +37,6 @@ argument gather(hipStream_t stream,
return args.back();
}
// Device (HIP) implementation of gather_torch. The last element of `args`
// is the preallocated output buffer; args[0] is the data tensor and args[1]
// the indices tensor, whose lens match the output's.
argument gather_torch(hipStream_t stream,
                      const migraphx::shape& output_shape,
                      std::vector<migraphx::argument> args,
                      std::size_t axis)
{
    visit_all(args.back(), args[0])([&](auto output, auto input) {
        std::size_t nelements = output_shape.elements();
        args[1].visit([&](auto indices) {
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                const auto* indices_ptr = device_cast(indices.data());
                auto* outptr            = device_cast(output.data());
                const auto* inptr       = device_cast(input.data());
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
                // Fix: build the indices descriptor from the indices tensor's
                // own shape, not the output shape. The lens are identical, but
                // the strides may differ when the indices argument is not in
                // standard layout, which would produce wrong linear offsets.
                hip_tensor_descriptor<ndim> desc_ind(indices.get_shape());
                gs_launch(stream, nelements)([=](auto i) {
                    // Read the index value at the output coordinate, then
                    // replace the axis coordinate before addressing the input.
                    auto lens  = desc_output.multi(i);
                    lens[axis] = indices_ptr[desc_ind.linear(lens)];
                    outptr[i]  = inptr[desc_input.linear(lens)];
                });
            });
        });
    });
    return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -22,19 +22,6 @@ argument hip_gather::compute(context& ctx,
return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
}
// Compute the result shape for the GPU gather_torch op. The trailing input is
// the preallocated output buffer's shape, which the operator itself does not
// expect, so it is stripped before delegating.
shape hip_gather_torch::compute_shape(std::vector<shape> inputs) const
{
    std::vector<shape> op_inputs(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(op_inputs);
}
// Launch the device gather_torch kernel on this context's stream.
argument hip_gather_torch::compute(context& ctx,
                                   const shape& output_shape,
                                   const std::vector<argument>& args) const
{
    auto stream = ctx.get_stream().get();
    return device::gather_torch(stream, output_shape, args, op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -10,18 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Gather using the onnx::gather algorithm (not used for now).
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
// Gather using torch.nn.gather semantics: the indices tensor has the same
// lens as the output; each output element reads the input at the coordinate
// stored in the indices tensor along `axis`. `args` carries data, indices,
// and the preallocated output buffer (last).
argument gather_torch(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -22,7 +22,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// use algorithm of onnx::gather (not used for now)
struct hip_gather
{
op::gather op;
......@@ -33,17 +32,6 @@ struct hip_gather
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
// use algorithm of torch.nn.gather
// GPU target wrapper for the gather_torch operator (torch.nn.gather
// semantics). Delegates shape computation and kernel launch to the
// out-of-line definitions.
struct hip_gather_torch
{
    op::gather_torch op;
    std::string name() const { return "gpu::gather_torch"; }
    // Defined out of line: strips the trailing output-buffer shape before
    // delegating to op.compute_shape.
    shape compute_shape(std::vector<shape> inputs) const;
    // Defined out of line: launches device::gather_torch on ctx's stream.
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    // Returns the index of the last argument — presumably the preallocated
    // output buffer the result aliases; NOTE(review): confirm against the
    // target's argument-passing convention.
    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -91,7 +91,7 @@ struct miopen_apply
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
add_extend_op<hip_gather_torch, op::gather_torch>("gather");
add_extend_op<hip_gather, op::gather>("gather");
add_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
......
......@@ -114,11 +114,11 @@ TEST_CASE(gather_test)
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 7.5f};
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
......@@ -134,11 +134,11 @@ TEST_CASE(gather_test)
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f};
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
......
......@@ -945,7 +945,7 @@ struct test_gather
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather_torch{axis}, a0, a1);
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
......
gather-example:
gather-example:Ž
'
data
indicesy"Gather*
......@@ -8,16 +8,14 @@



Z!
indices

Z
indices




b
b
y




B
\ No newline at end of file


B
\ No newline at end of file
......@@ -404,10 +404,9 @@ TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather_torch{axis}, l0, l1);
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment