"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "d9d35def3db5ebfd02e26173458e207b2732ecfd"
Commit 447eec93 authored by wsttiger
Browse files

Added concat operator for gpu (and tests)

parent b23aab5a
......@@ -14,6 +14,7 @@ add_library(migraph_device
device/add.cpp
device/add_relu.cpp
device/contiguous.cpp
device/concat.cpp
)
rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device)
......@@ -31,6 +32,7 @@ add_library(migraph_gpu
convolution.cpp
softmax.cpp
contiguous.cpp
concat.cpp
relu.cpp
add.cpp
batchnorm.cpp
......
#include <migraph/gpu/concat.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <utility>
namespace migraph {
namespace gpu {
shape hip_concat::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[op.axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[op.axis] += arg.get_shape().lens()[op.axis];
}
return offsets;
}
// Runs the concatenation: derives the per-argument start offsets in the
// output and hands them, together with the arguments, to the device routine.
// `ctx` is not used by this implementation.
argument hip_concat::compute(context& ctx,
                             const shape& output_shape,
                             const std::vector<argument>& args) const
{
    return device::concat(output_shape, args, compute_offsets(output_shape, args));
}
} // namespace gpu
} // namespace migraph
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
namespace migraph {
namespace gpu {
namespace device {
// Copies every input tensor into its slot of the output tensor.
//
// output_shape : shape of the concatenated result.
// args         : the input tensors followed by the output buffer as the
//                last element.
// offsets      : offsets[k] is the linear element offset in the output at
//                which input k starts.
//
// Returns the last element of `args` (the output buffer).
argument concat(const migraph::shape& output_shape,
                std::vector<migraph::argument> args,
                std::vector<std::size_t> offsets)
{
    // All arguments except the trailing output buffer are inputs.
    const std::size_t input_count = args.size() - 1;
    for(std::size_t idx = 0; idx < input_count; ++idx)
    {
        auto src                = args[idx];
        const std::size_t count = src.get_shape().elements();
        visit_all(args.back(), src)([&](auto output, auto input) {
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                const auto* from = input.data();
                // Start writing at this input's slot in the output.
                auto* to = output.data() + offsets[idx];
                hip_tensor_descriptor<ndim> out_desc(output.get_shape());
                hip_tensor_descriptor<ndim> in_desc(input.get_shape());
                // For each input element i: turn i into the input's
                // multi-index, then linearize that index with the output's
                // descriptor to find the destination.
                gs_launch(count)([=](auto i) {
                    to[out_desc.linear(in_desc.multi(i))] = from[i];
                });
            });
        });
    }
    return args.back();
}
} // namespace device
} // namespace gpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONCAT_HPP
#include <migraph/gpu/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/add.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <utility>
namespace migraph {
namespace gpu {
// GPU-side concat operator: wraps the generic op::concat and exposes the
// name / compute_shape / compute members, plus a helper that derives
// per-argument output offsets.
struct hip_concat
{
// The wrapped concat operator; its `axis` selects the concatenation dimension.
op::concat op;
// Operator name.
std::string name() const { return "gpu::concat"; }
// Result shape for the given input shapes. The definition drops the last
// entry of `inputs` before delegating to `op.compute_shape`.
shape compute_shape(std::vector<shape> inputs) const;
// Linear start offset inside `output_shape` for each entry of `args`,
// advancing along `op.axis`.
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const;
// Performs the concatenation by calling device::concat with `args` and the
// computed offsets; returns the argument that device::concat returns.
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
};
} // namespace gpu
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
namespace migraph {
namespace gpu {
namespace device {
// Copies each input in `args` into the output buffer (the last element of
// `args`), starting input k at linear element offset `offsets[k]`, and
// returns that output buffer.
// NOTE(review): this header includes nothing itself, yet uses shape,
// argument, std::vector and std::size_t; it appears to rely on the including
// file having pulled those in first — confirm and consider adding includes.
argument concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
} // namespace device
} // namespace gpu
} // namespace migraph
#endif
......@@ -21,6 +21,7 @@
#include <migraph/gpu/batchnorm.hpp>
#include <migraph/gpu/pooling.hpp>
#include <migraph/gpu/gemm.hpp>
#include <migraph/gpu/concat.hpp>
#include <utility>
namespace migraph {
......@@ -67,6 +68,10 @@ struct miopen_apply
{
check_shape(s, apply_contiguous(it));
}
else if(it->name() == "concat")
{
check_shape(s, apply_concat(it));
}
else if(it->name() == "batch_norm_inference")
{
check_shape(s, apply_batch_norm_inference(it));
......@@ -158,6 +163,15 @@ struct miopen_apply
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
}
// Lowers a "concat" instruction to hip_concat. An output allocation matching
// the instruction's shape is inserted and appended after the original
// inputs, so the replacement instruction receives it as its last argument.
instruction_ref apply_concat(instruction_ref ins)
{
    auto&& concat_op = any_cast<op::concat>(ins->get_operator());
    auto out_buffer  = insert_allocation(ins, ins->get_shape());

    // Original inputs first, output buffer last.
    std::vector<instruction_ref> operands = ins->inputs();
    operands.push_back(out_buffer);

    return prog->replace_instruction(ins, hip_concat{concat_op}, operands);
}
instruction_ref apply_batch_norm_inference(instruction_ref ins)
{
auto&& op = any_cast<op::batch_norm_inference>(ins->get_operator());
......
......@@ -518,6 +518,40 @@ struct test_conv_bn_relu_pooling
}
};
// Concatenates three int32 tensors of shapes {2,2}, {2,3} and {2,1} along
// axis 1.
struct test_concat
{
    migraph::program create_program() const
    {
        migraph::program prog;
        std::size_t concat_axis = 1;

        // The inputs differ only in their extent on the concat axis.
        auto x = prog.add_parameter("x", migraph::shape{migraph::shape::int32_type, {2, 2}});
        auto y = prog.add_parameter("y", migraph::shape{migraph::shape::int32_type, {2, 3}});
        auto z = prog.add_parameter("z", migraph::shape{migraph::shape::int32_type, {2, 1}});

        prog.add_instruction(migraph::op::concat{concat_axis}, x, y, z);
        return prog;
    }
};
// Concatenates three int32 tensors of shapes {2,2}, {3,2} and {1,2} along
// axis 0.
struct test_concat2
{
    migraph::program create_program() const
    {
        migraph::program prog;
        std::size_t concat_axis = 0;

        // The inputs differ only in their extent on the concat axis.
        auto x = prog.add_parameter("x", migraph::shape{migraph::shape::int32_type, {2, 2}});
        auto y = prog.add_parameter("y", migraph::shape{migraph::shape::int32_type, {3, 2}});
        auto z = prog.add_parameter("z", migraph::shape{migraph::shape::int32_type, {1, 2}});

        prog.add_instruction(migraph::op::concat{concat_axis}, x, y, z);
        return prog;
    }
};
struct test_conv_bn_relu_pooling2
{
static migraph::instruction_ref
......@@ -556,6 +590,8 @@ struct test_conv_bn_relu_pooling2
int main()
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_add>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment