Unverified commit dae94657, authored by Chris Austen and committed by GitHub
Browse files

Merge branch 'develop' into jit-reduce-reg

parents b013d991 56c43445
......@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try
{
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// Cannot tell if a dynamic shape will need to be made contiguous
if(new_shape.dynamic())
{
return false;
}
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
......@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
// Perform evaluations in parallel
// Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
// compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
});
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
......@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
T generic_read_file(const std::string& filename)
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg();
if(size < 1)
if(nbytes == 0)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg);
is.seekg(offset, std::ios::beg);
T buffer(size, 0);
if(not is.read(&buffer[0], size))
T buffer(nbytes, 0);
if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename);
return buffer;
}
std::vector<char> read_buffer(const std::string& filename)
std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{
return generic_read_file<std::vector<char>>(filename);
return generic_read_file<std::vector<char>>(filename, offset, nbytes);
}
std::string read_string(const std::string& filename)
......
......@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
if(s.elements() != 1 && not(s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if(e.get_shape().type() == shape::bool_type)
{
r = literal{e.at<bool>()};
}
else
{
e.visit_at([&](auto x) { r = literal{x}; });
}
return r;
}
......@@ -56,6 +65,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
if(ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
......
......@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t m_data{};
};
std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a);
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......@@ -197,7 +198,7 @@ struct check_shapes
*/
const check_shapes& same_ndims() const
{
if(not this->same([](const shape& s) { return s.max_lens().size(); }))
if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......@@ -232,6 +233,19 @@ struct check_shapes
return *this;
}
/*!
 * Check that every shape is packed and that its dimension permutation
 * (layout) is one of the allowed layouts.
 */
const check_shapes&
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
    // A shape passes when it is packed and its permutation matches one of
    // the requested layouts.
    auto has_allowed_layout = [&](const shape& s) {
        bool layout_found = contains(layouts, find_permutation(s));
        return s.packed() and layout_found;
    };
    if(not this->all_of(has_allowed_layout))
        MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
    return *this;
}
/*!
* Check all shapes are packed or broadcasted.
*/
......
......@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m,
......
......@@ -21,41 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// Bundles the two shapes relevant when evaluating an operator whose output
// may be dynamic: the shape stored on the instruction at compile time and the
// concrete shape computed from the actual input arguments at eval time.
struct dyn_output
{
    // original (possibly dynamic) shape from the instruction
    shape ins_shape;
    // static output shape computed at eval time using the input arguments;
    // equals ins_shape when the instruction shape is not dynamic
    shape computed_shape;
};
struct miopen_batch_norm_inference
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
op::batch_norm_inference op;
F ins_inputs;
template <class Self, class F>
static auto reflect(Self& self, F f)
operator dyn_output() const
{
return migraphx::reflect(self.op, f);
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
}
std::string name() const { return "gpu::batch_norm_inference"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
operator shape() const
{
return shapes.size() - 1;
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
} // namespace gpu
// Factory helper: wraps the callable f in a compute_output_shape so that the
// shape / dyn_output conversion operators can be invoked at evaluation time
// without spelling out the template argument.
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
    compute_output_shape<F> cos{f};
    return cos;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename);
std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
......@@ -121,6 +121,8 @@ struct instruction
bool can_eval() const;
bool is_undefined() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......
......@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
......@@ -31,18 +31,17 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct module_pass_manager;
/**
* Rewrite batchnorm to a multiply and add.
* Transform convolutions to nhwc
*/
struct rewrite_batchnorm
struct layout_nhwc
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(module& m) const;
std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
......@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill(start, end);
}
// Directly copies buffer of x
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{
......@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std::shared_ptr<char> buffer;
shape m_shape;
// Keeps the same data ordering as the given container
template <class Iterator>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
std::copy(start, end, output.begin());
});
}
};
......
......@@ -205,6 +205,12 @@ struct module
void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os,
......
......@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -56,12 +57,20 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
return {shape::int64_type, lens};
check_shapes{inputs, *this, true}.has(1);
const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1;
return {shape::int64_type, lens};
}
}
template <class T>
......@@ -79,19 +88,18 @@ struct argmax
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
......
......@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
check_shapes{inputs, static_cast<const Derived&>(*this), true}
.has(2)
.same_type()
.same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
if(s0.dynamic() or s1.dynamic())
{
if(s0 == s1)
return s0;
MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
}
else if(s0 == s1 and s0.packed())
{
return s0;
}
......@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
......
......@@ -27,23 +27,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
/**
* 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
* with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens = {};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -54,36 +61,86 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0);
auto t = s0.type();
if(inputs.size() == 1)
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
if(broadcast_lens.size() - axis < input.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
{
// don't think this can occur?
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
}
return output;
}
if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
else
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic())
{
MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape");
}
if(s0.ndim() != 1)
{
MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
", only handle ndim = 1");
}
if(axis >= s1.ndim())
{
MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
}
if(s1.dynamic())
{
s0 = s0.to_dynamic();
if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
return s1;
}
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
if(s0.lens()[0] != s1.lens()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" +
migraphx::to_string(s0.lens()[0]) +
" != " + migraphx::to_string(s1.lens()[axis]) + ")");
}
std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, s1.lens(), std::move(bcast_strides)};
return output;
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -33,11 +33,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Padding mode for convolution/pooling style operators.
// default_ is used for fixed (explicit) shape padding; same_lower and
// same_upper are used for dynamic auto-padding.
enum padding_mode_t
{
    default_, // NOLINT
    // NOTE(review): `same` and `valid` appear superseded by
    // same_lower/same_upper in this change — confirm against callers
    same,
    valid,
    same_lower,
    same_upper
};
......
......@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,19 +43,27 @@ namespace op {
struct contiguous
{
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();
if(s0.dynamic() or s0.standard())
{
return s0;
}
else
{
const auto& lens = s0.lens();
auto t = s0.type();
return {t, lens};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
......
......@@ -44,7 +44,7 @@ struct convert : unary<convert>
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0);
if(input.dynamic())
{
......
......@@ -41,9 +41,8 @@ struct convolution
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1};
int group = 1;
padding_mode_t padding_mode = default_;
bool use_dynamic_same_auto_pad = false;
int group = 1;
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -52,16 +51,15 @@ struct convolution
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.group, "group"),
f(self.padding_mode, "padding_mode"),
f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
f(self.padding_mode, "padding_mode"));
}
std::string name() const { return "convolution"; }
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != dilation.size())
{
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
}
......@@ -76,7 +74,8 @@ struct convolution
// num of dims of input and attribute should match
const auto input_size = inputs[0].max_lens().size();
const auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
if(input_size != padding_size / 2 + 2 && input_size != padding_size + 2)
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
}
......@@ -93,13 +92,6 @@ struct convolution
x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
op::padding_mode_t::same_lower};
if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
{
MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
}
if(x_shape.dynamic() or w_shape.dynamic())
{
return dynamic_compute_shape(x_shape, w_shape);
......@@ -161,7 +153,7 @@ struct convolution
dynamic_shape_push_back(w_shape);
const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
if(use_dynamic_same_auto_pad)
if(padding_mode != default_)
{
for(std::size_t i = 0; i < num_spatial_dims; ++i)
{
......
......@@ -61,8 +61,8 @@ struct deconvolution
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != dilation.size())
{
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment