"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "bc32a874522a3b1631b664952a9887d45d44b573"
Unverified Commit 1a229045 authored by YH's avatar YH Committed by GitHub
Browse files

Add interface for colo tesnor dp size (#3227)

parent 1653063f
......@@ -72,7 +72,7 @@ class ChunkManager:
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = tensor.process_group.dp_world_size()
dp_size = tensor.get_dp_world_size()
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
......
......@@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor):
def get_tp_world_size(self) -> int:
return self.process_group.tp_world_size()
def get_dp_world_size(self) -> int:
"""get_dp_world_size
get the dp world size of the tensor.
Returns:
int: dp world size
"""
return self.process_group.dp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
set dist spec and change the payloads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment