"examples/wuerstchen/text_to_image/__init__.py" did not exist on "6ab2dd18a4d17d90c92409886ac22a02acf25d7d"
Commit 67048d04 authored by Khalique's avatar Khalique
Browse files

merged onnx file

parents a5d6c29a 803eb3ce
......@@ -783,6 +783,48 @@ struct broadcast
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct multibroadcast
{
std::vector<std::size_t> output_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
}
std::string name() const { return "multibroadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
if(input.lens().empty())
MIGRAPH_THROW("inputs dimensions should be > 0");
if(input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
for(int i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == input.lens()[i])
{
bcast_strides[i + offset] = input.strides()[i];
}
}
return {t, output_lens, bcast_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct scalar
{
shape scalar_bcast;
......@@ -810,7 +852,9 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
}
};
......
......@@ -49,16 +49,17 @@ struct onnx_parser
onnx_parser()
{
add_generic_op("Add", op::add{});
add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::relu{});
add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
......@@ -93,12 +94,13 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_generic_op(std::string name, T x)
void add_broadcastable_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() == 2 and contains(attributes, "broadcast"))
if(args.size() != 2)
MIGRAPH_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
......@@ -110,7 +112,51 @@ struct onnx_parser
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
return prog.add_instruction(x, args);
}
else
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
}
});
}
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
......@@ -627,15 +673,17 @@ struct onnx_parser
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(
tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) {
if(not d.has_dim_value())
{
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) -> std::size_t {
if(not d.has_dim_value())
{
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
return {shape_type, dims};
}
};
......
......@@ -487,23 +487,44 @@ TEST_CASE(broadcast_test)
}
TEST_CASE(add_broadcast_test)
{
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2}};
std::vector<float> b_data{0, -1, -2, -3};
uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l1, l3);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold));
{
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2}};
std::vector<float> b_data{0, -1, -2, -3};
uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l1, l3);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold));
}
{
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3};
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold));
}
}
TEST_CASE(sub_test)
......
......@@ -145,4 +145,68 @@ TEST_CASE(slice_shape)
migraph::op::slice{{2}, {2}, {10}},
input);
}
TEST_CASE(multibroadcast)
{
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {5, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraph::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment