Unverified Commit c463f8ad authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[tensor] remove gpc in tensor tests (#1186)

parent 372f7914
from .process_group import ProcessGroup
from .tensor_spec import TensorSpec from .tensor_spec import TensorSpec
from .compute_spec import ComputeSpec, ComputePattern from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor from .colo_tensor import ColoTensor
...@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor ...@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor
from .dist_spec_mgr import DistSpecManager from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager from .param_op_hook import ParamOpHook, ParamOpHookManager
from . import distspec from . import distspec
from .process_group import ProcessGroup
__all__ = [ __all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor', 'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
......
...@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor): ...@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
1. directly init. 1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate()) >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor. >>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
>>> dims=[0], >>> dims=[0],
>>> num_partitions=[world_size]) >>> num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec) >>> tensor_spec = TensorSpec(shard_spec)
......
...@@ -5,7 +5,7 @@ from typing import List, Optional ...@@ -5,7 +5,7 @@ from typing import List, Optional
class ProcessGroup: class ProcessGroup:
""" """
Process Group contains group partition for Tensor Parallel and Data Parallel. Process Group contains group partition for Tensor Parallel and Data Parallel.
WARNING, the ProcessGroup must be used after torch.distributed.initialize() NOTE, the ProcessGroup must be used after torch.distributed.initialize()
args: args:
rank: the global rank of the current process. rank: the global rank of the current process.
ranks: List[int], a list of rank id belongings to this process group. ranks: List[int], a list of rank id belongings to this process group.
...@@ -15,16 +15,24 @@ class ProcessGroup: ...@@ -15,16 +15,24 @@ class ProcessGroup:
""" """
def __init__(self, def __init__(self,
rank: int, rank: Optional[int] = None,
ranks: List[int], ranks: Optional[List[int]] = None,
backend: str = 'nccl', backend: str = 'nccl',
tp_degree: Optional[int] = None, tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None: dp_degree: Optional[int] = None) -> None:
self._rank = rank assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
self._rank_list = ranks if rank is None:
self._rank = torch.distributed.get_rank()
else:
self._rank = rank
if ranks is None:
self._rank_list = list(range(torch.distributed.get_world_size()))
else:
self._rank_list = ranks
self._backend = backend self._backend = backend
self._world_size = len(self._rank_list) self._world_size = len(self._rank_list)
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
if dp_degree is None and tp_degree is None: if dp_degree is None and tp_degree is None:
self._dp_degree = self._world_size self._dp_degree = self._world_size
......
...@@ -11,11 +11,9 @@ from colossalai.utils import free_port ...@@ -11,11 +11,9 @@ from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \ from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColoOptimizer from colossalai.nn.optimizer import ColoOptimizer
from functools import partial from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_shard_equal, set_seed
def init_1d_row_linear(weight, pg: ProcessGroup): def init_1d_row_linear(weight, pg: ProcessGroup):
...@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name): ...@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
# A simple net with two stacked nn.Linear # A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) rank = torch.distributed.get_rank()
set_seed(1) set_seed(1)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
...@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name): ...@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
for p1, p2 in zip(model.parameters(), model_torch.parameters()): for p1, p2 in zip(model.parameters(), model_torch.parameters()):
p2.data.copy_(p1.data) p2.data.copy_(p1.data)
rank = gpc.get_local_rank(ParallelMode.GLOBAL) rank = torch.distributed.get_rank()
world_size = gpc.get_world_size(ParallelMode.GLOBAL) world_size = torch.distributed.get_world_size()
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
if 'bert' == model_name: if 'bert' == model_name:
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if not isinstance(p, ColoTensor): if not isinstance(p, ColoTensor):
...@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str): ...@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
# A simple net with two stacked nn.Linear # A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) rank = torch.distributed.get_rank()
set_seed(1) set_seed(1)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True) model = model_builder(checkpoint=True)
rank = gpc.get_local_rank(ParallelMode.GLOBAL) rank = torch.distributed.get_rank()
world_size = gpc.get_world_size(ParallelMode.GLOBAL) world_size = torch.distributed.get_world_size()
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size) pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
set_seed(1) set_seed(1)
...@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str): ...@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
data = data.to(get_current_device()) data = data.to(get_current_device())
label = label.to(get_current_device()) label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
# Bcast rank0 data to all processes # Bcast rank0 data to all processes
if criterion: if criterion:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment