Commit 53c4b899 authored by charlie's avatar charlie
Browse files

progress

parent c3861fb1
......@@ -45,7 +45,7 @@ void dead_code_elimination::apply(module& m) const
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
continue;
assert(bidistance(m, i, last) > 0);
......
......@@ -100,7 +100,7 @@ struct check_shapes
assert(end != nullptr);
if(begin != end)
{
if(begin->lens().size() < n)
if(begin->max_lens().size() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions");
}
......@@ -123,14 +123,14 @@ struct check_shapes
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
if(!this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
if(!this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......
......@@ -19,7 +19,7 @@ struct select_dependent_type
template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type;
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens);
bool normalize_attributes(operation& op, const shape& s);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -54,9 +54,9 @@ struct convolution
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
// num of dims of input and attribute should match
const auto input_size = inputs[0].max_lens().size();
const auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
......@@ -64,32 +64,63 @@ struct convolution
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
size_t kdims = input_size - 2;
if(kdims != this->kdims())
if(weights.dynamic())
{
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
MIGRAPHX_THROW("CONVOLUTION: dynamic weights not supported");
}
const size_t num_spatial_dims = input_size - 2;
if(num_spatial_dims != this->kdims())
{
MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
}
if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
if(!input.dynamic() and input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
for(size_t i = 0; i < kdims; i++)
auto calc_output_lens = [this, &weights, &num_spatial_dims, &padding_size](std::vector<std::size_t> lens)
{
std::vector<size_t> ret = {};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for(size_t i = 0; i < num_spatial_dims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
if(padding_size == 2 * num_spatial_dims)
{
// when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
padding_factor = padding[i] + padding[i + num_spatial_dims];
}
ret.push_back(
std::size_t(
std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
stride[i] +
1)));
(lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + padding_factor) / stride[i] + 1
)
)
);
}
return ret;
};
if(input.dynamic())
{
std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0), input.dyn_dims().at(1)};
auto min_spatial_dims = calc_output_lens(input.min_lens());
auto max_spatial_dims = calc_output_lens(input.max_lens());
auto opt_spatial_dims = calc_output_lens(input.opt_lens());
for (size_t i = 0; i < num_spatial_dims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return shape{input.type(), output_dyn_dims};
}
else
{
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto spatial_lens = calc_output_lens(input.lens());
std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x){ output_lens.push_back(x); });
return inputs[0].with_lens(output_lens);
}
}
size_t kdims() const
{
......
......@@ -114,7 +114,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs);
}
......
......@@ -421,8 +421,8 @@ operation instruction::normalized_operator() const
operation o = this->get_operator();
if(this->need_normalization())
{
auto lens = this->inputs().front()->get_shape().lens();
if(!normalize_attributes(o, lens))
auto s = this->inputs().front()->get_shape();
if(!normalize_attributes(o, s))
return this->get_operator();
}
return o;
......
......@@ -127,21 +127,22 @@ auto tune_pad_attribute(const value& val)
return result;
}
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
bool normalize_attributes(operation& op, const shape& s)
{
bool tuned = false;
auto attrs = op.attributes();
auto val = op.to_value();
if(attrs.contains("normalize_padding"))
{
auto num_dims = s.max_lens().size();
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2
auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start))
if(padding_size == 2 * (num_dims - padding_start))
tuned = true;
else if(padding_size != (lens.size() - padding_start))
else if(padding_size != (num_dims - padding_start))
MIGRAPHX_THROW("inconsistent padding size");
else
{
......@@ -171,7 +172,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>();
}
auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens);
auto result = tune_attribute(vec, axes, rv.without_key(),s.lens());
val[key] = result;
op.from_value(val);
val = op.to_value();
......@@ -180,7 +181,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else
{
auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
auto result = tune_attribute({num}, {num}, rv.without_key(), s.lens());
val[key] = result.front();
op.from_value(val);
val = op.to_value();
......
......@@ -20,9 +20,9 @@ void normalize_ops::apply(module& m) const
if(inputs.empty())
continue;
auto lens = inputs[0]->get_shape().lens();
auto s = inputs[0]->get_shape();
migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, lens))
if(normalize_attributes(tuned_op, s))
{
m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized();
......
......@@ -850,6 +850,76 @@ TEST_CASE(contiguous_test)
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(conv_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
input,
weights);
p.compile(migraphx::ref::target{});
std::vector<float> a = {
2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259,
0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051,
-0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158,
0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101,
0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297,
1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946,
0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338,
0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022,
0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792,
-2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896,
0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027,
-0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306};
std::vector<float> c = {
-0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512,
0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832,
0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622,
0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754,
0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272,
0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968,
-0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193,
0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292};
std::vector<float> sol = {-0.20817225,
0.87965256,
0.14958936,
-1.24887264,
-0.06540672,
0.20778663,
0.40456355,
-0.99900877,
0.4917807,
0.1994698,
0.64205718,
0.37798831,
-0.25315839,
0.44276932,
-0.16138598,
0.79344082};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv2d_padding_stride_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment