Unverified Commit 4b01da24 authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

[TP] change the check assert in split batch 2d (#772)

parent 846406a0
...@@ -739,11 +739,13 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -739,11 +739,13 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
""" """
dim_size = input_.size(dim) dim_size = input_.size(dim)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
if world_size <= 1:
return input_
assert dim_size % world_size == 0, \ assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
......
...@@ -770,11 +770,13 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -770,11 +770,13 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
""" """
dim_size = input_.size(dim) dim_size = input_.size(dim)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
if world_size <= 1:
return input_
assert dim_size % world_size == 0, \ assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment