Unverified Commit 4e24e65a authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Throw on calling `shape.lens()` or `shape.strides()` on a dynamic shape and vice versa (#1937)

Throwing on these calls catches dynamic shape errors earlier rather than having to backpedal from a bad call
parent 7f7998ba
......@@ -150,8 +150,8 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto c_type = compute_common_types(input_shapes);
auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
// following should work for a static or dynamic shape
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
auto s0 = inputs[0]->get_shape();
if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims)
{
inputs[0] = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
......@@ -159,7 +159,8 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
if(input->get_shape().dyn_dims() != c_dyn_dims)
auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims)
{
return m.insert_instruction(
ins,
......
......@@ -69,7 +69,7 @@ struct multibroadcast
auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--)
{
if(bcast_lens[i + offset] == s0.lens()[i])
{
......@@ -84,13 +84,13 @@ struct multibroadcast
if(s0.dynamic())
MIGRAPHX_THROW(
"MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(s0.lens().size() > output_lens.size())
if(s0.ndim() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
}
auto offset = output_lens.size() - s0.lens().size();
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
auto offset = output_lens.size() - s0.ndim();
for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
{
......@@ -119,7 +119,7 @@ struct multibroadcast
{
// output_lens will not be set for 2+ input version
auto bcast_lens = compute_common_lens(inputs);
auto offset = bcast_lens.size() - s0.lens().size();
auto offset = bcast_lens.size() - s0.ndim();
auto bcast_strides = make_bcast_strides(bcast_lens, offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -240,6 +240,10 @@ struct MIGRAPHX_EXPORT shape
template <class Iterator>
std::size_t index(Iterator start, Iterator last) const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT
......
......@@ -79,13 +79,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto x = args[0];
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
bool dyn_input = x->get_shape().dynamic();
auto ndims = x->get_shape().ndim();
auto ndims = x->get_shape().ndim();
assert(ndims >= 2);
auto kdims = ndims - 2;
std::vector<int64_t> axes(kdims);
......@@ -102,6 +100,12 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
(dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
if(dtype == shape::half_type and not convert_fp16)
{
if(x->get_shape().dynamic())
{
MIGRAPHX_THROW("PARSE_INSTANCENORM: half type not supported with dynamic shape "
"unless convert_fp16 is TRUE");
}
auto dims = x->get_shape().lens();
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
......@@ -122,13 +126,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
// both scale and bias.
instruction_ref scale_bcast;
instruction_ref bias_bcast;
if(dyn_input)
if(x->get_shape().dynamic())
{
scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
bias_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
}
else
{
auto dims = x->get_shape().lens();
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
bias_bcast =
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -273,9 +273,23 @@ shape shape::from_permutation(type_t t,
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::lens() const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
}
return impl->m_lens;
}
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
const std::vector<std::size_t>& shape::strides() const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
}
return impl->m_strides;
}
std::size_t shape::ndim() const
{
......@@ -535,7 +549,14 @@ bool shape::any_of_dynamic() const
});
}
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
if(not this->dynamic())
{
MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
}
return impl->m_dyn_dims;
}
std::vector<std::size_t> shape::min_lens() const
{
......@@ -679,12 +700,22 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
v = result;
result["type"] = migraphx::to_value(s.type_string());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
// avoid calling functions that will throw
if(s.dynamic())
{
result["lens"] = {};
result["strides"] = {};
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
}
else
{
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["dynamic_dimensions"] = {};
}
v = result;
}
void migraphx_from_value(const value& v, shape& s)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -228,6 +228,15 @@ TEST_CASE(test_shape_dynamic_errors)
EXPECT(test::throws([&] { s.index(std::vector<std::size_t>{0, 1}); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); }));
EXPECT(test::throws([&] { s.lens(); }));
EXPECT(test::throws([&] { s.strides(); }));
}
TEST_CASE(test_shape_static_dyn_dim_error)
{
using migraphx::shape;
migraphx::shape s{shape::float_type, {2, 3, 4}};
EXPECT(test::throws([&] { s.dyn_dims(); }));
}
TEST_CASE(test_shape_dynamic_serialize)
......
......@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p,
auto num = shapes.size();
for(std::size_t i = 0; i < num; ++i)
{
if(p.get_output_shapes()[i].lens() != shapes[i].lens())
auto output_shape = p.get_output_shapes()[i];
if(output_shape.dynamic() and shapes[i].dynamic())
{
if(output_shape.dyn_dims() != shapes[i].dyn_dims())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its dynamic output dimensions");
}
}
else if(not(output_shape.dynamic() or shapes[i].dynamic()))
{
if(output_shape.lens() != shapes[i].lens())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its static output dimensions");
}
}
else
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape");
throw std::runtime_error(
"Compiling program with " + name +
" alters its output dimensions (static shape vs dynamic shape)");
}
}
if(t.name() != "ref")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment