Unverified Commit ad414ba9 authored by Scott Thornton's avatar Scott Thornton Committed by GitHub
Browse files

Merge pull request #88 from ROCmSoftwarePlatform/concat

Concat
parents 945e89e0 96a3a8be
...@@ -314,6 +314,57 @@ struct contiguous ...@@ -314,6 +314,57 @@ struct contiguous
} }
}; };
struct concat
{
std::size_t axis = 0;
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis] += arg.get_shape().lens()[axis];
}
return offsets;
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
{
MIGRAPH_THROW("Number of input tensors should exceed 0");
}
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{
if(l != axis)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
}))
{
MIGRAPH_THROW("Non-axis dimensions should match");
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return {type, new_lens};
}
};
struct slice struct slice
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
......
...@@ -282,6 +282,34 @@ struct cpu_contiguous ...@@ -282,6 +282,34 @@ struct cpu_contiguous
} }
}; };
struct cpu_concat
{
    op::concat op;
    std::string name() const { return "cpu::concat"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    // Allocates the output and copies each input tensor into its slice of
    // the result, using the offsets computed by the reference operator.
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
        for(std::size_t l = 0; l < args.size(); l++)
        {
            // Reference, not a copy — the original copied the argument each iteration.
            const auto& argl      = args[l];
            std::size_t nelements = argl.get_shape().elements();
            visit_all(result, argl)([&](auto output, auto input) {
                // View the destination region with the output's strides so
                // writes land at the right place along the concat axis.
                auto slice_shape =
                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                // cppcheck-suppress useStlAlgorithm
                for(std::size_t i = 0; i < nelements; i++)
                {
                    slice[i] = input[i];
                }
            });
        }
        return result;
    }
};
struct cpu_gemm struct cpu_gemm
{ {
op::gemm op; op::gemm op;
...@@ -568,20 +596,20 @@ struct cpu_apply ...@@ -568,20 +596,20 @@ struct cpu_apply
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>();
apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>(); apply_map["exp"] = simple_op<cpu_unary<exp_op>>();
apply_map["exp"] = simple_op<cpu_unary<exp_op>>(); apply_map["neg"] = simple_op<cpu_unary<neg_op>>();
apply_map["neg"] = simple_op<cpu_unary<neg_op>>(); apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["sin"] = simple_op<cpu_unary<sin_op>>(); apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>(); apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>(); apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["add"] = simple_op<cpu_binary<add_op>>(); apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>(); apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>(); apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["softmax"] = simple_op<softmax2d>(); apply_map["softmax"] = simple_op<softmax2d>();
} }
......
...@@ -14,6 +14,7 @@ add_library(migraph_device ...@@ -14,6 +14,7 @@ add_library(migraph_device
device/add.cpp device/add.cpp
device/add_relu.cpp device/add_relu.cpp
device/contiguous.cpp device/contiguous.cpp
device/concat.cpp
) )
rocm_clang_tidy_check(migraph_device) rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device) target_link_libraries(migraph_device migraph hip::device)
...@@ -31,6 +32,7 @@ add_library(migraph_gpu ...@@ -31,6 +32,7 @@ add_library(migraph_gpu
convolution.cpp convolution.cpp
softmax.cpp softmax.cpp
contiguous.cpp contiguous.cpp
concat.cpp
relu.cpp relu.cpp
leaky_relu.cpp leaky_relu.cpp
add.cpp add.cpp
......
#include <migraph/gpu/concat.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <utility>
namespace migraph {
namespace gpu {
// Shape inference for gpu::concat: lowering appends the preallocated output
// buffer as the final input, so drop the trailing shape before delegating to
// the reference operator.
shape hip_concat::compute_shape(std::vector<shape> inputs) const
{
    inputs.erase(std::prev(inputs.end()));
    return op.compute_shape(std::move(inputs));
}
// Computes each input's starting element offset in the output, then performs
// the actual copies on the device.
argument
hip_concat::compute(context&, const shape& output_shape, const std::vector<argument>& args) const
{
    const auto arg_offsets = op.compute_offsets(output_shape, args);
    return device::concat(output_shape, args, arg_offsets);
}
} // namespace gpu
} // namespace migraph
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
namespace migraph {
namespace gpu {
namespace device {
// Copies each input tensor (args[0..n-2]) into the preallocated output
// buffer (args.back()), starting at the element offset offsets[l], and
// returns the filled output buffer.
argument concat(const migraph::shape& output_shape,
                std::vector<migraph::argument> args,
                std::vector<std::size_t> offsets)
{
    for(std::size_t l = 0; l < args.size() - 1; l++)
    {
        // Reference, not a copy — the original copied the argument each iteration.
        const auto& argl      = args[l];
        std::size_t nelements = argl.get_shape().elements();
        visit_all(args.back(), argl)([&](auto output, auto input) {
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                auto* outptr      = output.data() + offsets[l];
                const auto* inptr = input.data();
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
                // Map each flat input element index to its multi-index, then
                // to the corresponding linear position in the output region.
                gs_launch(nelements)(
                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
            });
        });
    }
    return args.back();
}
} // namespace device
} // namespace gpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONCAT_HPP
#include <migraph/gpu/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/add.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <utility>
namespace migraph {
namespace gpu {
// GPU lowering of op::concat. By convention, the lowering pass appends a
// preallocated output buffer as the last argument, so compute_shape strips
// the trailing shape before delegating to the reference operator.
struct hip_concat
{
    op::concat op;
    std::string name() const { return "gpu::concat"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
};
} // namespace gpu
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
namespace migraph {
namespace gpu {
namespace device {
argument
concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
} // namespace device
} // namespace gpu
} // namespace migraph
#endif
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <migraph/gpu/batchnorm.hpp> #include <migraph/gpu/batchnorm.hpp>
#include <migraph/gpu/pooling.hpp> #include <migraph/gpu/pooling.hpp>
#include <migraph/gpu/gemm.hpp> #include <migraph/gpu/gemm.hpp>
#include <migraph/gpu/concat.hpp>
#include <utility> #include <utility>
namespace migraph { namespace migraph {
...@@ -72,6 +73,10 @@ struct miopen_apply ...@@ -72,6 +73,10 @@ struct miopen_apply
{ {
check_shape(s, apply_contiguous(it)); check_shape(s, apply_contiguous(it));
} }
else if(it->name() == "concat")
{
check_shape(s, apply_concat(it));
}
else if(it->name() == "batch_norm_inference") else if(it->name() == "batch_norm_inference")
{ {
check_shape(s, apply_batch_norm_inference(it)); check_shape(s, apply_batch_norm_inference(it));
...@@ -173,6 +178,15 @@ struct miopen_apply ...@@ -173,6 +178,15 @@ struct miopen_apply
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output); return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
} }
// Lower a reference concat to gpu::concat: the device op writes into a
// preallocated buffer, which is appended as the final argument.
instruction_ref apply_concat(instruction_ref ins)
{
    auto&& concat_op = any_cast<op::concat>(ins->get_operator());
    auto allocation  = insert_allocation(ins, ins->get_shape());
    auto new_args    = ins->inputs();
    new_args.push_back(allocation);
    return prog->replace_instruction(ins, hip_concat{concat_op}, new_args);
}
instruction_ref apply_batch_norm_inference(instruction_ref ins) instruction_ref apply_batch_norm_inference(instruction_ref ins)
{ {
auto&& op = any_cast<op::batch_norm_inference>(ins->get_operator()); auto&& op = any_cast<op::batch_norm_inference>(ins->get_operator());
......
...@@ -47,6 +47,56 @@ void slice_test() ...@@ -47,6 +47,56 @@ void slice_test()
} }
} }
// Exercises cpu concat on two int32 cases and checks the values, the
// resulting lens, and the resulting (packed) strides.
void concat_test()
{
    // Case 1: concatenate along axis 1 — rows are stitched side by side.
    {
        migraph::program p;
        std::size_t axis   = 1;
        std::vector<int> a = {0, 1, 5, 6};
        std::vector<int> b = {2, 3, 4, 7, 8, 9};
        std::vector<int> c = {10, 20};
        migraph::shape sa{migraph::shape::int32_type, {2, 2}};
        migraph::shape sb{migraph::shape::int32_type, {2, 3}};
        migraph::shape sc{migraph::shape::int32_type, {2, 1}};
        auto la = p.add_literal(migraph::literal{sa, a});
        auto lb = p.add_literal(migraph::literal{sb, b});
        auto lc = p.add_literal(migraph::literal{sc, c});
        p.add_instruction(migraph::op::concat{axis}, la, lb, lc);
        p.compile(migraph::cpu::cpu_target{});
        auto result           = p.eval({});
        std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
        std::vector<int> results_vector(gold.size());
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
        EXPECT(migraph::verify_range(results_vector, gold));
        EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6})));
        EXPECT(
            migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
    }
    // Case 2: concatenate along axis 0 — tensors are stacked vertically.
    {
        migraph::program p;
        std::size_t axis   = 0;
        std::vector<int> a = {0, 1, 2, 3};
        std::vector<int> b = {4, 5, 6, 7, 8, 9};
        std::vector<int> c = {10, 11};
        migraph::shape sa{migraph::shape::int32_type, {2, 2}};
        migraph::shape sb{migraph::shape::int32_type, {3, 2}};
        migraph::shape sc{migraph::shape::int32_type, {1, 2}};
        auto la = p.add_literal(migraph::literal{sa, a});
        auto lb = p.add_literal(migraph::literal{sb, b});
        auto lc = p.add_literal(migraph::literal{sc, c});
        p.add_instruction(migraph::op::concat{axis}, la, lb, lc);
        p.compile(migraph::cpu::cpu_target{});
        auto result           = p.eval({});
        std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
        std::vector<int> results_vector(gold.size());
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
        EXPECT(migraph::verify_range(results_vector, gold));
        EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
        EXPECT(
            migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1})));
    }
}
void squeeze_test() void squeeze_test()
{ {
{ {
...@@ -933,6 +983,7 @@ void contiguous_test() ...@@ -933,6 +983,7 @@ void contiguous_test()
int main() int main()
{ {
concat_test();
slice_test(); slice_test();
squeeze_test(); squeeze_test();
unsqueeze_test(); unsqueeze_test();
......
...@@ -36,9 +36,9 @@ void fwd_conv_batchnorm_rewrite_test() ...@@ -36,9 +36,9 @@ void fwd_conv_batchnorm_rewrite_test()
auto create_program = [&]() { auto create_program = [&]() {
migraph::program p; migraph::program p;
auto x = p.add_literal(xs, xdata); auto x = p.add_literal(xs, xdata);
auto w = p.add_literal(ws, wdata); auto w = p.add_literal(ws, wdata);
auto conv = p.add_instruction(migraph::op::convolution{{0, 0}, {1, 1}, {1, 1}}, x, w); auto conv = p.add_instruction(migraph::op::convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, x, w);
auto scale = p.add_literal(migraph::literal{vars, {3.0f}}); auto scale = p.add_literal(migraph::literal{vars, {3.0f}});
auto bias = p.add_literal(migraph::literal{vars, {8.1f}}); auto bias = p.add_literal(migraph::literal{vars, {8.1f}});
auto mean = p.add_literal(migraph::literal{vars, {4.0f}}); auto mean = p.add_literal(migraph::literal{vars, {4.0f}});
......
...@@ -577,6 +577,40 @@ struct test_conv_bn_relu_pooling ...@@ -577,6 +577,40 @@ struct test_conv_bn_relu_pooling
} }
}; };
// Verifies gpu concat against the cpu reference: three parameter inputs
// concatenated along axis 1 (differing column counts, equal row counts).
struct test_concat
{
    migraph::program create_program() const
    {
        migraph::program p;
        const std::size_t axis = 1;
        migraph::shape sx{migraph::shape::int32_type, {2, 2}};
        migraph::shape sy{migraph::shape::int32_type, {2, 3}};
        migraph::shape sz{migraph::shape::int32_type, {2, 1}};
        auto x = p.add_parameter("x", sx);
        auto y = p.add_parameter("y", sy);
        auto z = p.add_parameter("z", sz);
        p.add_instruction(migraph::op::concat{axis}, x, y, z);
        return p;
    }
};
// Verifies gpu concat against the cpu reference: three parameter inputs
// concatenated along axis 0 (differing row counts, equal column counts).
struct test_concat2
{
    migraph::program create_program() const
    {
        migraph::program p;
        const std::size_t axis = 0;
        migraph::shape sx{migraph::shape::int32_type, {2, 2}};
        migraph::shape sy{migraph::shape::int32_type, {3, 2}};
        migraph::shape sz{migraph::shape::int32_type, {1, 2}};
        auto x = p.add_parameter("x", sx);
        auto y = p.add_parameter("y", sy);
        auto z = p.add_parameter("z", sz);
        p.add_instruction(migraph::op::concat{axis}, x, y, z);
        return p;
    }
};
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2
{ {
static migraph::instruction_ref static migraph::instruction_ref
...@@ -615,6 +649,8 @@ struct test_conv_bn_relu_pooling2 ...@@ -615,6 +649,8 @@ struct test_conv_bn_relu_pooling2
int main() int main()
{ {
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_add>(); verify_program<test_add>();
verify_program<test_triadd>(); verify_program<test_triadd>();
verify_program<test_triadd2>(); verify_program<test_triadd2>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment