Unverified Commit 1a229045 authored by YH's avatar YH Committed by GitHub
Browse files

Add interface for colo tesnor dp size (#3227)

parent 1653063f
...@@ -72,7 +72,7 @@ class ChunkManager: ...@@ -72,7 +72,7 @@ class ChunkManager:
if tensor.numel() > chunk_size: if tensor.numel() > chunk_size:
chunk_size = tensor.numel() chunk_size = tensor.numel()
dp_size = tensor.process_group.dp_world_size() dp_size = tensor.get_dp_world_size()
chunk_size = chunk_size + (-chunk_size % dp_size) chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk( chunk = Chunk(
......
...@@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor): ...@@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor):
def get_tp_world_size(self) -> int: def get_tp_world_size(self) -> int:
return self.process_group.tp_world_size() return self.process_group.tp_world_size()
def get_dp_world_size(self) -> int:
"""get_dp_world_size
get the dp world size of the tensor.
Returns:
int: dp world size
"""
return self.process_group.dp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec): def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec """set_dist_spec
set dist spec and change the payloads. set dist spec and change the payloads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment