Commit dac5563d authored by umangyadav's avatar umangyadav
Browse files

Merge remote-tracking branch 'upstream/develop' into resnet50_partition

parents 95f2cdb9 7e2a550c
......@@ -2190,6 +2190,32 @@ TEST_CASE(prefix_scan_sum)
}
}
TEST_CASE(prefix_scan_sum_dyn)
{
{
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
expect_shape(
s,
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
TEST_CASE(prefix_scan_sum_dyn_2d)
{
{
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}, {3, 7}};
migraphx::shape s{migraphx::shape::float_type, dd};
expect_shape(
s,
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
......
......@@ -5886,6 +5886,29 @@ TEST_CASE(prefix_scan_sum_1d)
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_dyn_1d)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
input);
p.compile(migraphx::make_target("ref"));
std::vector<float> a = {1, 2, 3, 4, 5, 6};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {6}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_2d)
{
{
......
......@@ -70,7 +70,7 @@ struct ${struct_name}
{
using std::swap;
auto * derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -181,7 +181,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (not private_detail_te_handle_mem_var.unique())
if (private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment