Commit 2d1ebf8f authored by dongcl
Browse files

fix flux bug

parent f25c421e
@@ -581,7 +581,7 @@ class LinearRS(torch.autograd.Function):
 wgrad_compute = False
 if wgrad:
-    if ctx.sequence_parallel
+    if ctx.sequence_parallel:
         dim_size = list(grad_output.size())
         dim_size[0] = dim_size[0] * world_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment