Commit e846ae72 authored by Paul's avatar Paul
Browse files

Format

parent dc71e23d
...@@ -225,8 +225,10 @@ struct gemm_impl ...@@ -225,8 +225,10 @@ struct gemm_impl
b_stride = get_batch_stride(input_shapes[1]); b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]); c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride; d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate( num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0)) if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0))
{ {
// If the batch dimension of B is broadcasted, then we can // If the batch dimension of B is broadcasted, then we can
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment