Commit 0cf10d1c authored by yuguo's avatar yuguo
Browse files

fix

parent 7a923605
...@@ -581,7 +581,7 @@ def general_grouped_gemm( ...@@ -581,7 +581,7 @@ def general_grouped_gemm(
workspaces[0].shape[0], workspaces[0].shape[0],
accumulate, accumulate,
use_split_accumulator, use_split_accumulator,
) )[0]
for i in range(num_gemms): for i in range(num_gemms):
out[i].copy_(dw[i]) out[i].copy_(dw[i])
return out, bias, gelu_input return out, bias, gelu_input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment