Unverified Commit 84acaea0 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic shape hip::copy_to_gpu and hip::copy_from_gpu (#1694)

Updates the hip::copy_to_gpu and hip::copy_from_gpu operators to work with dynamic shapes

Allows for offload_copy to be used with dynamic batch

Changed assert in select_module because the argument might now be smaller with how offload_copy will work with dynamic batch. (maximum buffer size will be used)
parent 08360e83
...@@ -125,7 +125,7 @@ struct select_module ...@@ -125,7 +125,7 @@ struct select_module
auto ps = param_shapes.at(name); auto ps = param_shapes.at(name);
if(a.get_shape() != ps) if(a.get_shape() != ps)
{ {
assert(ps.bytes() == a.get_shape().bytes()); assert(ps.bytes() <= a.get_shape().bytes());
return std::make_pair(name, a.reshape(ps)); return std::make_pair(name, a.reshape(ps));
} }
else else
......
...@@ -120,12 +120,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -120,12 +120,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if(not std::all_of( if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); })) types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { visit_tensor_size(s.ndim(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); })); s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
}); });
} }
...@@ -133,12 +131,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -133,12 +131,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); }); visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
} }
template <class F> template <class F>
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/dyn_output.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -112,7 +113,7 @@ struct hip_copy_to_gpu ...@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; } std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2).same_type(); check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -121,6 +122,10 @@ struct hip_copy_to_gpu ...@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
if(args.size() == 1) if(args.size() == 1)
return input; return input;
argument result = args[1].share(); argument result = args[1].share();
if(result.get_shape().dynamic())
{
result = result.reshape(args[0].get_shape());
}
gpu_copy(ctx, input, result); gpu_copy(ctx, input, result);
// Associate the input since it was registered with hip // Associate the input since it was registered with hip
return {result.get_shape(), [input, result]() mutable { return result.data(); }}; return {result.get_shape(), [input, result]() mutable { return result.data(); }};
...@@ -138,19 +143,24 @@ struct hip_copy_from_gpu ...@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; } std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2).same_type(); check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
{ {
if(args.size() == 1) if(args.size() == 1)
{ {
argument result = allocate_gpu(output_shape, true); argument result = allocate_gpu(dyn_out.computed_shape, true);
gpu_copy(ctx, args[0], result); gpu_copy(ctx, args[0], result);
return result; return result;
} }
copy_from_gpu(ctx, args[0], args[1]); argument input = args[0].share();
if(input.get_shape().dynamic())
{
input = input.reshape(args[1].get_shape());
}
copy_from_gpu(ctx, input, args[1]);
return args[1]; return args[1];
} }
std::ptrdiff_t output_alias(const std::vector<shape>& args) const std::ptrdiff_t output_alias(const std::vector<shape>& args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment