Unverified Commit 63563da2 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add dimensionality reduction to device functions (#587)



* Add reduce dims

* Formatting

* Reduce dims on the gpu

* Formatting

* Fix tidy issues

* Convert to assert

* Reduce dims for contiguous

* Formatting

* Remove move

* Fix arguments used

* Formatting

* Fix warnings

* Formatting
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
parent a5648d9c
......@@ -21,6 +21,7 @@ add_library(migraphx
instruction.cpp
program.cpp
quantization.cpp
reduce_dims.cpp
remap.cpp
shape.cpp
schedule.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP

#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Collapse adjacent dimensions that are stored contiguously in every one of
/// the given shapes, producing shapes of (possibly) lower rank that describe
/// the same data. If the shapes cannot be reduced consistently (for example,
/// they have different ranks or incompatible lengths), the input vector is
/// returned unchanged.
std::vector<shape> reduce_dims(const std::vector<shape>& shapes);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
{
std::vector<std::size_t> new_lens;
for(const auto& s : shapes)
{
assert(n < s.lens().size());
if((n + 1) >= s.lens().size())
return false;
auto astride = s.strides()[n];
auto alen = s.lens()[n];
auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1];
if(astride == bstride * blen)
{
new_lens.push_back(alen * blen);
}
}
if(new_lens.size() != shapes.size())
return false;
std::size_t i = 0;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.erase(lens.begin() + n);
strides.erase(strides.begin() + n);
lens[n] = new_lens[i];
s = shape{s.type(), lens, strides};
i++;
}
return true;
}
// Repeatedly merge dimension n with its successor until no further merge is
// possible, then return the next dimension index to examine.
//
// Fix: the previous loop condition also required `n < shapes.size()`, which
// compared a *dimension index* against the *number of shapes* and could stop
// merging prematurely when n >= shapes.size(). reduce_dim already
// bounds-checks n against each shape's rank, so looping on its result alone
// is sufficient.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    while(reduce_dim(shapes, n))
    {
    }
    return n + 1;
}
// Scan every dimension from the front, merging runs of contiguous
// dimensions across all shapes in place. Assumes `shapes` is non-empty
// (the public reduce_dims entry point guards the empty case).
void reduce_dim_all(std::vector<shape>& shapes)
{
    // The rank is re-read on every iteration because each successful merge
    // removes one dimension from every shape. Using `n + 1 < size` instead
    // of `n < size - 1` avoids unsigned underflow (and the resulting
    // out-of-bounds scan) when a shape has zero dimensions.
    std::size_t n = 0;
    while(n + 1 < shapes.front().lens().size())
        n = reduce_dim_all(shapes, n);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
{
return std::accumulate(
shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
std::vector<std::size_t> result;
const auto* x = &s.lens();
const auto* y = &lens;
if(x->size() > y->size())
std::swap(x, y);
std::transform(
x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
return std::max(a, b);
});
return result;
});
}
// Build a shape with the base `lens` whose strides index into `s`:
// dimensions where `s` matches the base get packed (row-major) strides;
// mismatched dimensions are tolerated only when one side is 1, and are left
// with stride 0 (broadcast). Returns an empty shape when the lengths are
// incompatible.
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
    assert(s.lens().size() == lens.size());
    std::vector<std::size_t> rstrides(lens.size(), 0);
    std::size_t next_stride = 1;
    // Walk from the innermost dimension outward.
    for(auto i = lens.size(); i > 0; i--)
    {
        const auto d = i - 1;
        if(lens[d] == s.lens()[d])
        {
            rstrides[d] = next_stride;
            next_stride *= lens[d];
        }
        else if(lens[d] != 1 and s.lens()[d] != 1)
        {
            return shape{};
        }
    }
    return shape{s.type(), lens, rstrides};
}
// Collapse adjacent dimensions that every shape stores contiguously, so that
// elementwise device kernels can iterate over fewer, larger dimensions.
// Returns the input unchanged when the shapes have different ranks or are
// otherwise incompatible.
std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto result = shapes;
    auto base = base_lens(shapes);
    for(auto&& s : shapes)
    {
        // Shapes of different rank cannot be reduced consistently; bail out.
        if(s.lens().size() != base.size())
            return shapes;
        if(s.lens() == base)
            continue;
        // A shape whose lens differ from the base is treated as a broadcast
        // mask; mask_shape yields an empty shape when it is incompatible,
        // which the size check below detects. Valid masks are appended as
        // extra constraints on the merging.
        auto mshape = mask_shape(s, base);
        if(mshape.lens().size() != base.size())
            return shapes;
        result.push_back(mshape);
    }
    reduce_dim_all(result);
    // Drop the temporary mask shapes; only the reduced originals are returned.
    result.erase(result.begin() + shapes.size(), result.end());
    return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -105,7 +105,6 @@ add_library(migraphx_gpu
quant_convolution.cpp
softmax.cpp
logsoftmax.cpp
contiguous.cpp
concat.cpp
leaky_relu.cpp
batchnorm.cpp
......
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Inputs are {data, output-buffer} (two shapes); only the first input
// participates in the shape computation — the second is the preallocated
// output that compute() writes into.
shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2);
    return op.compute_shape({inputs.at(0)});
}
// Copy args[0] into the preallocated output buffer args[1] on the context's
// stream and return that buffer (the output aliases the last argument).
argument miopen_contiguous::compute(context& ctx,
                                    shape output_shape,
                                    const std::vector<argument>& args) const
{
    // The output allocation must already match the computed shape and be
    // standard (densely packed) so the kernel performs a plain linear write.
    assert(output_shape == args[1].get_shape());
    assert(output_shape.standard());
    (void)output_shape; // only referenced by the asserts; silence NDEBUG warning
    device::contiguous(ctx.get_stream().get(), args.at(1), args.at(0));
    return args.at(1);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,7 +22,7 @@ argument concat(hipStream_t stream,
auto output_shape = shape{
arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
auto output = argument{output_shape, args.back().data() + byte_offset};
contiguous(stream, std::move(output), arg);
contiguous(stream, output, arg);
}
return args.back();
}
......
......@@ -7,9 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void contiguous(hipStream_t stream, argument result, argument arg)
void contiguous(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, std::move(result), std::move(arg))([](auto x) __device__ { return x; });
nary(stream, result, arg)([](auto x) __device__ { return x; });
}
} // namespace device
......
......@@ -34,17 +34,7 @@ constexpr void visit_tensor_size(index_int n, F f)
f(std::integral_constant<index_int, 4>{});
break;
}
case 5:
{
f(std::integral_constant<index_int, 5>{});
break;
}
case 6:
{
f(std::integral_constant<index_int, 6>{});
break;
}
default: throw std::runtime_error("Tensor size dim out of range");
default: throw std::runtime_error("Tensor dims " + std::to_string(n) + " out of range");
}
}
......
......@@ -169,68 +169,16 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation);
}
struct hip_triadd
struct hip_triadd : ternary_device<hip_triadd, &device::add>
{
std::string name() const { return "hip::triadd"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_triadd_clip
struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
{
std::string name() const { return "hip::triadd_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(6);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(ctx.get_stream().get(),
args.at(5),
args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
return args.at(5);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_add_clip
struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
{
std::string name() const { return "hip::add_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(
ctx.get_stream().get(), args.at(4), args.at(0), args.at(1), args.at(2), args.at(3));
return args.at(4);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
......@@ -273,43 +221,12 @@ struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
struct hip_mul_add
struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
{
std::string name() const { return "hip::mul_add"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_mul_add_relu
struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
{
std::string name() const { return "hip::mul_add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
......
......@@ -3,6 +3,8 @@
#include <migraphx/shape.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,22 +12,17 @@ namespace gpu {
struct context;
struct miopen_contiguous
struct miopen_contiguous : unary_device<miopen_contiguous, &device::contiguous>
{
op::contiguous op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::contiguous"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context&, shape output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
shape compute_shape(const std::vector<shape>& inputs) const
{
return shapes.size() - 1;
check_shapes{inputs, *this}.has(2);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
}
};
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void contiguous(hipStream_t stream, argument result, argument arg);
void contiguous(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
......
......@@ -7,6 +7,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
......@@ -15,95 +16,126 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&)>
struct unary_device : oper<Derived>
template <class Derived, std::size_t N>
struct device_base : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const
template <class Self, class F>
static auto reflect(Self&, F)
{
check_shapes{inputs, *this}.has(2);
auto s = inputs.at(0);
if(s.packed())
{
return s;
}
else
{
return {s.type(), s.lens()};
}
return pack();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
std::vector<shape> reduce_shapes;
void finalize(context&, const shape&, const std::vector<shape>& inputs)
{
F(ctx.get_stream().get(), args[1], args[0]);
return args[1];
reduce_shapes = reduce_dims(inputs);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
argument get_arg(const std::vector<argument>& args, std::size_t i) const
{
return shapes.size() - 1;
if(reduce_shapes.empty())
return args[i];
return args.at(i).reshape(reduce_shapes.at(i));
}
};
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&, const argument&)>
struct binary_device : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
check_shapes{inputs, *this}.has(N + 1);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
if(std::all_of(inputs.begin(), inputs.end() - 1, [&](auto s) { return s == s0; }) and
s0.packed())
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&)>
struct unary_device : device_base<Derived, 1>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
F(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
F(ctx.get_stream().get(), this->get_arg(args, 1), this->get_arg(args, 0));
return args[1];
}
};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&, const argument&)>
struct binary_device : device_base<Derived, 2>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return shapes.size() - 1;
F(ctx.get_stream().get(),
this->get_arg(args, 2),
this->get_arg(args, 0),
this->get_arg(args, 1));
return args[2];
}
};
template <class Derived,
void (*F)(
hipStream_t, const argument&, const argument&, const argument&, const argument&)>
struct ternary_device : oper<Derived>
struct ternary_device : device_base<Derived, 3>
{
shape compute_shape(const std::vector<shape>& inputs) const
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
check_shapes{inputs, *this}.has(4);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s0 == s1 and s1 == s2 and s0.packed())
{
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
F(ctx.get_stream().get(),
this->get_arg(args, 3),
this->get_arg(args, 0),
this->get_arg(args, 1),
this->get_arg(args, 2));
return args[3];
}
};
template <class Derived,
void (*F)(hipStream_t,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&)>
struct quaternary_device : device_base<Derived, 4>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
F(ctx.get_stream().get(), args[3], args[0], args[1], args[2]);
return args[3];
F(ctx.get_stream().get(),
this->get_arg(args, 4),
this->get_arg(args, 0),
this->get_arg(args, 1),
this->get_arg(args, 2),
this->get_arg(args, 3));
return args[4];
}
};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
template <class Derived,
void (*F)(hipStream_t,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&)>
struct quinary_device : device_base<Derived, 5>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return shapes.size() - 1;
F(ctx.get_stream().get(),
this->get_arg(args, 5),
this->get_arg(args, 0),
this->get_arg(args, 1),
this->get_arg(args, 2),
this->get_arg(args, 3),
this->get_arg(args, 4));
return args[5];
}
};
......
......@@ -169,8 +169,8 @@ struct miopen_apply
add_generic_op<hip_ceil>("ceil");
add_generic_op<hip_floor>("floor");
add_generic_op<hip_recip>("recip");
add_generic_op<miopen_contiguous>("contiguous");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax");
add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
......
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include "test.hpp"
// Convenience helper: build a float shape with the given lengths
// (strides default to standard row-major packing).
migraphx::shape make_shape(std::vector<std::size_t> lens)
{
    return migraphx::shape{migraphx::shape::float_type, std::move(lens)};
}
// Convenience helper: build a float shape with explicit lengths and strides.
migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
{
    return migraphx::shape{migraphx::shape::float_type, std::move(lens), std::move(strides)};
}
// All three inputs share the same standard (densely packed, row-major)
// layout, so every dimension collapses into a single linear dimension.
TEST_CASE(same_standard)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64 * 3 * 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, is, is};
    std::vector<migraphx::shape> eshapes = {os, os, os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// The middle shape broadcasts dims 0, 2 and 3 (stride 0); only the two
// trailing 7x7 dims are contiguous in every shape, so only they merge.
TEST_CASE(same_broadcast1)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// The broadcast shape stores dims 1 and 2 contiguously (strides 8 and 1)
// and broadcasts the rest, so dims (3, 8) merge to 24 and the trailing
// (7, 7) merge to 49; the leading 64 cannot merge into either.
TEST_CASE(same_broadcast2)
{
    auto is = make_shape({64, 3, 8, 7, 7});
    auto os = make_shape({64, 8 * 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// The middle shape has its last two dims transposed, which blocks merging
// them; only the leading (64, 3) dims are contiguous in every shape.
TEST_CASE(same_transposed)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64 * 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is};
    std::vector<migraphx::shape> eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// The {1, 3, 1, 1} shape has different lens from the others, so it is
// handled via its mask against the base lens; the trailing spatial dims
// still merge for every shape.
TEST_CASE(different_masked1)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 1, 1}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// Two differently-masked shapes ({1,3,1,1} and {64,1,7,7}) against the same
// base: both masks are compatible, so the trailing 7x7 dims merge everywhere.
TEST_CASE(different_masked2)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {
        is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})};
    std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// Dim 2 is 2 vs 7 with neither side equal to 1, so no valid mask exists and
// reduce_dims must return the input unchanged.
TEST_CASE(different_incompatible)
{
    auto is = make_shape({64, 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 2, 1}), is};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(ishapes == rshapes);
}
// Mixed ranks (2 vs 4) cannot be reduced consistently; the input is
// returned unchanged.
TEST_CASE(different_ranks)
{
    auto is = make_shape({64, 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3}), is};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(ishapes == rshapes);
}
// The second shape's custom strides make only the trailing 56x56 dims
// contiguous in both shapes, so only those two merge.
TEST_CASE(transposed1)
{
    std::vector<migraphx::shape> ishapes = {
        make_shape({8, 28, 4, 56, 56}),
        make_shape({8, 28, 4, 56, 56}, {351232, 3136, 87808, 56, 1})};
    std::vector<migraphx::shape> eshapes = {
        make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// An empty input vector yields an empty result (explicit early-out in
// reduce_dims).
TEST_CASE(empty)
{
    auto rshapes = migraphx::reduce_dims({});
    EXPECT(rshapes.empty());
}
// Test driver entry point: runs every registered TEST_CASE.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment