Unverified Commit f1ecad75 authored by jungpark-mlir's avatar jungpark-mlir Committed by GitHub
Browse files

Add relaxed standard shape assertion (#1416)

Reiterate the assertion on the standard shape but relax it for the multibroadcast ops deliberately inserted to explicit the broadcast.
parent ef22e8b1
......@@ -196,6 +196,7 @@ struct mlir_program
MlirType make_tensor(const shape& s) const
{
assert(s.standard());
std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet(
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
......@@ -371,7 +372,11 @@ struct mlir_program
mlir_operation_state& add_results(const std::vector<shape>& outputs)
{
auto x = prog->make_tensors(outputs);
std::vector<shape> reshaped(outputs.size());
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
mlirOperationStateAddResults(&op_state, x.size(), x.data());
return *this;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment