Unverified Commit ede8bfa6 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic allocate (#2079)

Makes a version of allocate that takes in dimensions and allocates a buffer
Going to create a simplify_dynamic_ops compiler pass that will use the use_shape_attr flag
The ONNX op ConstantOfShape needs the buffer to be filled with a specific value, so going to make another PR for that and a fill operator
parent b00489b3
......@@ -36,20 +36,48 @@ namespace op {
struct allocate
{
shape s{};
// for dynamic allocate to set the buffer type
shape::type_t buf_type = shape::half_type;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
return pack(f(self.s, "shape"), f(self.buf_type, "buf_type"));
}
std::string name() const { return "allocate"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
migraphx::check_shapes{inputs, *this, true}.has(0);
return s;
migraphx::check_shapes{inputs, *this, true}.has(0, 1);
// check if shape attribute is not default
if(s != shape())
{
return s;
}
else
{
const auto& out_dims = inputs.at(0);
assert(not out_dims.dynamic());
assert(out_dims.ndim() == 1);
std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
shape::dynamic_dimension{0, max_val});
return {buf_type, dyn_dims};
}
}
argument compute(const shape& output_shape, const std::vector<argument>&) const
argument compute(const shape& output_shape, const std::vector<argument>& args) const
{
return {output_shape};
if(args.empty())
{
return {output_shape};
}
else
{
std::vector<std::size_t> output_dims(output_shape.ndim());
args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
return {shape{buf_type, output_dims}};
}
}
};
......
......@@ -82,6 +82,33 @@ void throws_shape(const migraphx::shape&, Ts...)
"An expected shape should not be passed to throws_shape function");
}
TEST_CASE(allocate_static)
{
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}};
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}));
}
TEST_CASE(allocate_dyn)
{
migraphx::shape input{migraphx::shape::int64_type, {2}};
auto max_val = std::numeric_limits<std::size_t>::max();
std::vector<migraphx::shape::dynamic_dimension> dyn_dims(
2, migraphx::shape::dynamic_dimension{0, max_val});
expect_shape(migraphx::shape{migraphx::shape::float_type, dyn_dims},
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}),
input);
}
TEST_CASE(allocate_dyn_with_shape_attr)
{
migraphx::shape input{migraphx::shape::int64_type, {4}};
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
expect_shape(shape_attr,
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}),
input);
}
TEST_CASE(argmax_axis0)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(allocate_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4}};
auto out_dims = mm->add_parameter("out_dims", s);
mm->add_instruction(migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}),
out_dims);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> data = {2, 3, 4, 4};
params["out_dims"] = migraphx::argument(s, data.data());
auto result = p.eval(params).back();
migraphx::shape sresult{migraphx::shape::float_type, {2, 3, 4, 4}};
result.visit([&](auto output) { EXPECT(output.get_shape() == sresult); });
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment