Commit f660152c authored by superhao1995's avatar superhao1995 Committed by Frank Lee
Browse files

[NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258)


Co-authored-by: default avatarResearch <research@soccf-snr3-017.comp.nus.edu.sg>
parent 9738fb0f
......@@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor,
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size,
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size,
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment