Commit 865e71c3 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'dynamic_prefix_scan' into multinomial_parse

parents 63952fb9 7c034a5e
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,6 +21,13 @@ ...@@ -21,6 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
/**
* Parent struct for prefix scan ops. A prefix scan is a mathematical entity useful
* in parallelizing various computations. Given a list of numbers, a prefix scan
* op returns an equal size list of running totals of the values. Other operations
* besides addition can be supported by child ops.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP #define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
...@@ -64,9 +71,9 @@ struct prefix_scan_op : op_name<Derived> ...@@ -64,9 +71,9 @@ struct prefix_scan_op : op_name<Derived>
auto s = inputs.front(); auto s = inputs.front();
if(s.dynamic()) if(s.dynamic())
{ {
return {s.type(), s.max_lens()}; return s;
} }
if(s.broadcasted()) else if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
...@@ -76,8 +83,9 @@ struct prefix_scan_op : op_name<Derived> ...@@ -76,8 +83,9 @@ struct prefix_scan_op : op_name<Derived>
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
shape output_shape(dyn_out.computed_shape);
argument result{output_shape}; argument result{output_shape};
auto s = args[0].get_shape(); auto s = args[0].get_shape();
if(s == output_shape) if(s == output_shape)
......
...@@ -2091,6 +2091,19 @@ TEST_CASE(prefix_scan_sum) ...@@ -2091,6 +2091,19 @@ TEST_CASE(prefix_scan_sum)
} }
} }
TEST_CASE(prefix_scan_sum_dyn)
{
{
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
expect_shape(
s,
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
TEST_CASE(quant_convolution_shape) TEST_CASE(quant_convolution_shape)
{ {
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}}; migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
......
...@@ -5589,6 +5589,29 @@ TEST_CASE(prefix_scan_sum_1d) ...@@ -5589,6 +5589,29 @@ TEST_CASE(prefix_scan_sum_1d)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(prefix_scan_sum_dyn_1d)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
input);
p.compile(migraphx::make_target("ref"));
std::vector<float> a = {1, 2, 3, 4, 5, 6};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {6}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_2d) TEST_CASE(prefix_scan_sum_2d)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment