Unverified Commit dae94657 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into jit-reduce-reg

parents b013d991 56c43445
...@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try try
{ {
shape new_shape = ins->get_operator().compute_shape(inputs, mods); shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// Cannot tell if a dynamic shape will need to be made contiguous
if(new_shape.dynamic())
{
return false;
}
// If the output shape is a standard shape, no need to try its output // If the output shape is a standard shape, no need to try its output
if(new_shape.standard()) if(new_shape.standard())
{ {
...@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
} }
} }
// Perform evaluations in parallel // Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size()); std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) { par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{}; auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front(); auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); // compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
}); });
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++) for(size_t i = 0; i < const_instructions.size(); i++)
{ {
auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
...@@ -30,23 +30,31 @@ namespace migraphx { ...@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
T generic_read_file(const std::string& filename) T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{ {
std::ifstream is(filename, std::ios::binary | std::ios::ate); std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg(); if(nbytes == 0)
if(size < 1) {
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename); MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg); is.seekg(offset, std::ios::beg);
T buffer(size, 0); T buffer(nbytes, 0);
if(not is.read(&buffer[0], size)) if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
std::vector<char> read_buffer(const std::string& filename) std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{ {
return generic_read_file<std::vector<char>>(filename); return generic_read_file<std::vector<char>>(filename, offset, nbytes);
} }
std::string read_string(const std::string& filename) std::string read_string(const std::string& filename)
......
...@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins) ...@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front()); return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar())) if(s.elements() != 1 && not(s.scalar()))
return {}; return {};
if(not ins->can_eval()) if(not ins->can_eval())
return {}; return {};
auto e = ins->eval(); auto e = ins->eval();
literal r{}; literal r{};
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if(e.get_shape().type() == shape::bool_type)
{
r = literal{e.at<bool>()};
}
else
{
e.visit_at([&](auto x) { r = literal{x}; }); e.visit_at([&](auto x) { r = literal{x}; });
}
return r; return r;
} }
...@@ -56,6 +65,8 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -56,6 +65,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("pointwise", false)) if(not ins->get_operator().attributes().get("pointwise", false))
continue; continue;
if(ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op")); assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
......
...@@ -107,6 +107,7 @@ struct argument : raw_data<argument> ...@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t m_data{}; data_t m_data{};
}; };
std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a); void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a); void migraphx_from_value(const value& v, argument& a);
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -197,7 +198,7 @@ struct check_shapes ...@@ -197,7 +198,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(not this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
...@@ -232,6 +233,19 @@ struct check_shapes ...@@ -232,6 +233,19 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed with certain layouts
*/
const check_shapes&
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s));
}))
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
return *this;
}
/*! /*!
* Check all shapes are packed or broadcasted. * Check all shapes are packed or broadcasted.
*/ */
......
...@@ -36,6 +36,9 @@ struct operation; ...@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1); std::vector<std::size_t> s1);
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
shape common_shape(const std::vector<shape>& shapes); shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m, instruction_ref insert_common_op(module& m,
......
...@@ -21,41 +21,55 @@ ...@@ -21,41 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context; struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
struct miopen_batch_norm_inference /**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{ {
op::batch_norm_inference op; F ins_inputs;
template <class Self, class F> operator dyn_output() const
static auto reflect(Self& self, F f)
{ {
return migraphx::reflect(self.op, f); return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
} }
std::string name() const { return "gpu::batch_norm_inference"; } operator shape() const
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
} }
}; };
} // namespace gpu template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename); std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename); std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size); void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
...@@ -121,6 +121,8 @@ struct instruction ...@@ -121,6 +121,8 @@ struct instruction
bool can_eval() const; bool can_eval() const;
bool is_undefined() const;
argument eval(bool check_eval = true) const; argument eval(bool check_eval = true) const;
void finalize(context& ctx); void finalize(context& ctx);
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP #define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
...@@ -31,18 +31,17 @@ ...@@ -31,18 +31,17 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
/** /**
* Rewrite batchnorm to a multiply and add. * Transform convolutions to nhwc
*/ */
struct rewrite_batchnorm struct layout_nhwc
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "layout_nhwc"; }
void apply(module& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#endif
...@@ -80,6 +80,7 @@ struct literal : raw_data<literal> ...@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill(start, end); fill(start, end);
} }
// Directly copies buffer of x
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)> template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s) literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{ {
...@@ -107,25 +108,15 @@ struct literal : raw_data<literal> ...@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std::shared_ptr<char> buffer; std::shared_ptr<char> buffer;
shape m_shape; shape m_shape;
// Keeps the same data ordering as the given container
template <class Iterator> template <class Iterator>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
assert(std::distance(start, end) == m_shape.elements()); assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) { m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get())); auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) { std::copy(start, end, output.begin());
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
}); });
});
}
} }
}; };
......
...@@ -205,6 +205,12 @@ struct module ...@@ -205,6 +205,12 @@ struct module
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, print_cpp(std::ostream& os,
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -56,13 +57,21 @@ struct argmax ...@@ -56,13 +57,21 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto lens = inputs[0].lens(); const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1; lens[axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
}
template <class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
...@@ -79,19 +88,18 @@ struct argmax ...@@ -79,19 +88,18 @@ struct argmax
max_index = i; max_index = i;
} }
} }
return max_index; return max_index;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis]; auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i); auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num); output[i] = this->calc_argmax(input, data_idx, batch_item_num);
}); });
}); });
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -60,10 +61,19 @@ struct binary : op_name<Derived> ...@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
value attributes() const { return base_attributes(); } value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims(); check_shapes{inputs, static_cast<const Derived&>(*this), true}
.has(2)
.same_type()
.same_dims();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed()) if(s0.dynamic() or s1.dynamic())
{
if(s0 == s1)
return s0;
MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
}
else if(s0 == s1 and s0.packed())
{ {
return s0; return s0;
} }
...@@ -81,9 +91,9 @@ struct binary : op_name<Derived> ...@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), std::transform(input1.begin(),
input1.end(), input1.end(),
......
...@@ -27,23 +27,30 @@ ...@@ -27,23 +27,30 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This /**
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are * 1 input version:
/// computed from multi-indicies by computing the inner product on the multi-index with the strides. * Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want * broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd * that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is * input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
/// obvious from there that we can negate the effects of a given axis by setting the stride of that * work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
/// axis to zero. * with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D static shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens; std::vector<std::size_t> broadcast_lens = {};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -54,36 +61,86 @@ struct broadcast ...@@ -54,36 +61,86 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input = inputs.at(0); check_shapes{inputs, *this, true}.has(1, 2);
auto t = input.type(); auto s0 = inputs.at(0);
auto t = s0.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); if(inputs.size() == 1)
// the broacast op is deprecated now, so not handling the negative {
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis is out of range"); MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
} }
if(broadcast_lens.size() - axis < s0.lens().size())
if(broadcast_lens.size() - axis < input.lens().size())
{ {
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims"); MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
} }
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match"); MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
} }
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)}; shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements()) if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size"); {
// don't think this can occur?
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
}
return output;
}
else
{
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic())
{
MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape");
}
if(s0.ndim() != 1)
{
MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
", only handle ndim = 1");
}
if(axis >= s1.ndim())
{
MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
}
if(s1.dynamic())
{
s0 = s0.to_dynamic();
if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
return s1;
}
if(s0.lens()[0] != s1.lens()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" +
migraphx::to_string(s0.lens()[0]) +
" != " + migraphx::to_string(s1.lens()[axis]) + ")");
}
std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, s1.lens(), std::move(bcast_strides)};
return output; return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -33,11 +33,11 @@ namespace migraphx { ...@@ -33,11 +33,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
enum padding_mode_t enum padding_mode_t
{ {
default_, // NOLINT default_, // NOLINT
same,
valid,
same_lower, same_lower,
same_upper same_upper
}; };
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -42,19 +43,27 @@ namespace op { ...@@ -42,19 +43,27 @@ namespace op {
struct contiguous struct contiguous
{ {
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
if(inputs.front().standard()) auto s0 = inputs.front();
return inputs.front(); if(s0.dynamic() or s0.standard())
auto lens = inputs.at(0).lens(); {
auto t = inputs.at(0).type(); return s0;
}
else
{
const auto& lens = s0.lens();
auto t = s0.type();
return {t, lens}; return {t, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
assert(output_shape.standard()); assert(dyn_out.computed_shape.standard());
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()); output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
......
...@@ -44,7 +44,7 @@ struct convert : unary<convert> ...@@ -44,7 +44,7 @@ struct convert : unary<convert>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
if(input.dynamic()) if(input.dynamic())
{ {
......
...@@ -43,7 +43,6 @@ struct convolution ...@@ -43,7 +43,6 @@ struct convolution
int group = 1; int group = 1;
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
bool use_dynamic_same_auto_pad = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -52,16 +51,15 @@ struct convolution ...@@ -52,16 +51,15 @@ struct convolution
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.group, "group"), f(self.group, "group"),
f(self.padding_mode, "padding_mode"), f(self.padding_mode, "padding_mode"));
f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
void check_attribute_size() const void check_attribute_size() const
{ {
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() == dilation.size())) stride.size() != dilation.size())
{ {
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes"); MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
} }
...@@ -76,7 +74,8 @@ struct convolution ...@@ -76,7 +74,8 @@ struct convolution
// num of dims of input and attribute should match // num of dims of input and attribute should match
const auto input_size = inputs[0].max_lens().size(); const auto input_size = inputs[0].max_lens().size();
const auto padding_size = padding.size(); const auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
if(input_size != padding_size / 2 + 2 && input_size != padding_size + 2)
{ {
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
} }
...@@ -93,13 +92,6 @@ struct convolution ...@@ -93,13 +92,6 @@ struct convolution
x_shape.lens().at(1) != (w_shape.lens().at(1) * group)) x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers"); MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
op::padding_mode_t::same_lower};
if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
{
MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
}
if(x_shape.dynamic() or w_shape.dynamic()) if(x_shape.dynamic() or w_shape.dynamic())
{ {
return dynamic_compute_shape(x_shape, w_shape); return dynamic_compute_shape(x_shape, w_shape);
...@@ -161,7 +153,7 @@ struct convolution ...@@ -161,7 +153,7 @@ struct convolution
dynamic_shape_push_back(w_shape); dynamic_shape_push_back(w_shape);
const size_t num_spatial_dims = x_shape.max_lens().size() - 2; const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
if(use_dynamic_same_auto_pad) if(padding_mode != default_)
{ {
for(std::size_t i = 0; i < num_spatial_dims; ++i) for(std::size_t i = 0; i < num_spatial_dims; ++i)
{ {
......
...@@ -61,8 +61,8 @@ struct deconvolution ...@@ -61,8 +61,8 @@ struct deconvolution
void check_attribute_size() const void check_attribute_size() const
{ {
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() == dilation.size())) stride.size() != dilation.size())
{ {
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes"); MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment