Unverified Commit f1ecad75 authored by jungpark-mlir's avatar jungpark-mlir Committed by GitHub
Browse files

Add relaxed standard shape assertion (#1416)

Reiterate the assertion on the standard shape but relax it for the multibroadcast ops deliberately inserted to explicit the broadcast.
parent ef22e8b1
...@@ -196,6 +196,7 @@ struct mlir_program ...@@ -196,6 +196,7 @@ struct mlir_program
MlirType make_tensor(const shape& s) const MlirType make_tensor(const shape& s) const
{ {
assert(s.standard());
std::vector<int64_t> lens(s.lens().begin(), s.lens().end()); std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet( return mlirRankedTensorTypeGet(
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull()); lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
...@@ -371,7 +372,11 @@ struct mlir_program ...@@ -371,7 +372,11 @@ struct mlir_program
mlir_operation_state& add_results(const std::vector<shape>& outputs) mlir_operation_state& add_results(const std::vector<shape>& outputs)
{ {
auto x = prog->make_tensors(outputs); std::vector<shape> reshaped(outputs.size());
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
mlirOperationStateAddResults(&op_state, x.size(), x.data()); mlirOperationStateAddResults(&op_state, x.size(), x.data());
return *this; return *this;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment