Commit 8acc9bd6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 0f85317e
......@@ -172,32 +172,34 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
auto a_strides = inputs[0].strides();
auto dim_0 = a_strides.size() - 2;
if (a_strides.size() > 2)
auto dim_0 = a_strides.size() - 2;
if(a_strides.size() > 2)
{
if (!std::all_of(a_strides.begin(), a_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(a_strides.begin() + dim_0, a_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
if(!std::all_of(a_strides.begin(), a_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(a_strides.begin() + dim_0, a_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
{
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) + "} is transposed!");
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(a_strides) +
"} is transposed!");
}
}
auto b_strides = inputs[1].strides();
if (b_strides.size() > 2)
if(b_strides.size() > 2)
{
if (!std::all_of(b_strides.begin(), b_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(b_strides.begin() + dim_0, b_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
if(!std::all_of(b_strides.begin(), b_strides.begin() + dim_0, [&](auto batch_size) {
return std::all_of(b_strides.begin() + dim_0, b_strides.end(), [&](auto data_size) {
return batch_size >= data_size;
});
}))
{
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) + "} is transposed!");
MIGRAPHX_THROW("DOT: batch size of b {" + to_string_range(b_strides) +
"} is transposed!");
}
}
return op.compute_shape(input_shapes);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment