Commit 53c4b899 authored by charlie's avatar charlie
Browse files

progress

parent c3861fb1
...@@ -45,7 +45,7 @@ void dead_code_elimination::apply(module& m) const ...@@ -45,7 +45,7 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity // Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity") i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(m, i, last) > 0); assert(bidistance(m, i, last) > 0);
......
...@@ -100,7 +100,7 @@ struct check_shapes ...@@ -100,7 +100,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() < n) if(begin->max_lens().size() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -123,14 +123,14 @@ struct check_shapes ...@@ -123,14 +123,14 @@ struct check_shapes
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
......
...@@ -19,7 +19,7 @@ struct select_dependent_type ...@@ -19,7 +19,7 @@ struct select_dependent_type
template <class T, class... Ts> template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type; using dependent_type = typename select_dependent_type<T, Ts...>::type;
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens); bool normalize_attributes(operation& op, const shape& s);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -54,9 +54,9 @@ struct convolution ...@@ -54,9 +54,9 @@ struct convolution
{ {
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size(); check_attribute_size();
// dim num of input and attribute should match // num of dims of input and attribute should match
auto input_size = inputs[0].lens().size(); const auto input_size = inputs[0].max_lens().size();
auto padding_size = padding.size(); const auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2)) if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{ {
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
...@@ -64,31 +64,62 @@ struct convolution ...@@ -64,31 +64,62 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
size_t kdims = input_size - 2; if(weights.dynamic())
if(kdims != this->kdims())
{ {
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size"); MIGRAPHX_THROW("CONVOLUTION: dynamic weights not supported");
} }
const size_t num_spatial_dims = input_size - 2;
if(input.lens().at(1) != (weights.lens().at(1) * group)) if(num_spatial_dims != this->kdims())
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
for(size_t i = 0; i < kdims; i++)
{ {
auto padding_factor = 2 * padding[i]; MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
stride[i] +
1)));
} }
return inputs[0].with_lens(output_lens); if(!input.dynamic() and input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
auto calc_output_lens = [this, &weights, &num_spatial_dims, &padding_size](std::vector<std::size_t> lens)
{
std::vector<size_t> ret = {};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for(size_t i = 0; i < num_spatial_dims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * num_spatial_dims)
{
// when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
padding_factor = padding[i] + padding[i + num_spatial_dims];
}
ret.push_back(
std::size_t(
std::max<std::ptrdiff_t>(
1,
(lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + padding_factor) / stride[i] + 1
)
)
);
}
return ret;
};
if(input.dynamic())
{
std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0), input.dyn_dims().at(1)};
auto min_spatial_dims = calc_output_lens(input.min_lens());
auto max_spatial_dims = calc_output_lens(input.max_lens());
auto opt_spatial_dims = calc_output_lens(input.opt_lens());
for (size_t i = 0; i < num_spatial_dims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return shape{input.type(), output_dyn_dims};
}
else
{
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto spatial_lens = calc_output_lens(input.lens());
std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x){ output_lens.push_back(x); });
return inputs[0].with_lens(output_lens);
}
} }
size_t kdims() const size_t kdims() const
......
...@@ -114,7 +114,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -114,7 +114,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
......
...@@ -421,8 +421,8 @@ operation instruction::normalized_operator() const ...@@ -421,8 +421,8 @@ operation instruction::normalized_operator() const
operation o = this->get_operator(); operation o = this->get_operator();
if(this->need_normalization()) if(this->need_normalization())
{ {
auto lens = this->inputs().front()->get_shape().lens(); auto s = this->inputs().front()->get_shape();
if(!normalize_attributes(o, lens)) if(!normalize_attributes(o, s))
return this->get_operator(); return this->get_operator();
} }
return o; return o;
......
...@@ -127,21 +127,22 @@ auto tune_pad_attribute(const value& val) ...@@ -127,21 +127,22 @@ auto tune_pad_attribute(const value& val)
return result; return result;
} }
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) bool normalize_attributes(operation& op, const shape& s)
{ {
bool tuned = false; bool tuned = false;
auto attrs = op.attributes(); auto attrs = op.attributes();
auto val = op.to_value(); auto val = op.to_value();
if(attrs.contains("normalize_padding")) if(attrs.contains("normalize_padding"))
{ {
auto num_dims = s.max_lens().size();
auto padding = val.at(attrs.at("normalize_padding").to<std::string>()); auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size(); auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2 // for now, assume the dimensions to pad start at dim 2
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start)) if(padding_size == 2 * (num_dims - padding_start))
tuned = true; tuned = true;
else if(padding_size != (lens.size() - padding_start)) else if(padding_size != (num_dims - padding_start))
MIGRAPHX_THROW("inconsistent padding size"); MIGRAPHX_THROW("inconsistent padding size");
else else
{ {
...@@ -171,7 +172,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -171,7 +172,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens); auto result = tune_attribute(vec, axes, rv.without_key(),s.lens());
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -180,7 +181,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -180,7 +181,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens); auto result = tune_attribute({num}, {num}, rv.without_key(), s.lens());
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -20,9 +20,9 @@ void normalize_ops::apply(module& m) const ...@@ -20,9 +20,9 @@ void normalize_ops::apply(module& m) const
if(inputs.empty()) if(inputs.empty())
continue; continue;
auto lens = inputs[0]->get_shape().lens(); auto s = inputs[0]->get_shape();
migraphx::operation tuned_op = ins->get_operator(); migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, lens)) if(normalize_attributes(tuned_op, s))
{ {
m.replace_instruction(ins, tuned_op, inputs); m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized(); ins->set_normalized();
......
...@@ -850,6 +850,76 @@ TEST_CASE(contiguous_test) ...@@ -850,6 +850,76 @@ TEST_CASE(contiguous_test)
EXPECT(migraphx::verify_range(results_vector, data)); EXPECT(migraphx::verify_range(results_vector, data));
} }
TEST_CASE(conv_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
input,
weights);
p.compile(migraphx::ref::target{});
std::vector<float> a = {
2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259,
0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051,
-0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158,
0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101,
0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297,
1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946,
0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338,
0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022,
0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792,
-2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896,
0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027,
-0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306};
std::vector<float> c = {
-0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512,
0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832,
0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622,
0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754,
0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272,
0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968,
-0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193,
0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292};
std::vector<float> sol = {-0.20817225,
0.87965256,
0.14958936,
-1.24887264,
-0.06540672,
0.20778663,
0.40456355,
-0.99900877,
0.4917807,
0.1994698,
0.64205718,
0.37798831,
-0.25315839,
0.44276932,
-0.16138598,
0.79344082};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv2d_padding_stride_test) TEST_CASE(conv2d_padding_stride_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment