Unverified Commit 8cdce039 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[ColoTensor] improves init functions. (#1150)

parent 8106d7b8
...@@ -35,7 +35,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): ...@@ -35,7 +35,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
data: Optional[torch.Tensor] = None, data: Optional[torch.Tensor] = None,
requires_grad: bool = True, requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None: spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec) self._tensor_spec = copy(spec)
self._type = TensorType.MODEL self._type = TensorType.MODEL
self._graph_node = None self._graph_node = None
......
from .op_wrapper import _COLOSSAL_OPS from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy from copy import copy
import torch import torch
from torch.overrides import get_default_nowrap_functions
from colossalai.tensor import TensorSpec from colossalai.tensor import TensorSpec
from .const import TensorType
from colossalai.tensor import distspec from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec from colossalai.tensor.distspec import _DistSpec
from torch.overrides import get_default_nowrap_functions
def _convert_output(output): def _convert_output(output):
...@@ -18,34 +19,54 @@ def _convert_output(output): ...@@ -18,34 +19,54 @@ def _convert_output(output):
class ColoTensor(torch.Tensor): class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
1. It contains a torch.Tensor as an attribute. Args:
2. It supports lazy init the tensor's payload. data (torch.Tensor): a torch tensor used as the payload the colotensor.
3. It can hijack the torch functions which using ColoTensors as args to our customized functions. spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
4. It supports distributing the tensor's payload to the shards among processes. (TODO)
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
""" """
def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor': def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
"""__new__
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
Returns:
ColoTensor: a ColoTensor wrappers the data.
"""
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad) return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None: def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec) self._tensor_spec = copy(spec)
self._type = TensorType.NONMODEL self._type = TensorType.NONMODEL
self._graph_node = None self._graph_node = None
@property @property
def spec(self) -> TensorSpec: def spec(self) -> TensorSpec:
return self._spec return self._tensor_spec
def set_spec(self, spec: TensorSpec) -> None: def set_spec(self, spec: TensorSpec) -> None:
spec = copy(spec) spec = copy(spec)
self.convert_to_dist_spec_(spec.dist_spec) self._convert_to_dist_spec(spec.dist_spec)
self._spec = spec self._tensor_spec = spec
def has_spec(self) -> bool: def has_spec(self) -> bool:
return self._spec.parallel_action is not None return self._tensor_spec.parallel_action is not None
def is_model_data(self) -> bool: def is_model_data(self) -> bool:
return self._type == TensorType.MODEL return self._type == TensorType.MODEL
...@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor): ...@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
def is_model_data(self) -> bool: def is_model_data(self) -> bool:
return self._type == TensorType.MODEL return self._type == TensorType.MODEL
def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None: def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec) self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
self._spec.dist_spec = dist_spec self._tensor_spec.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor': def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
spec = copy(self._spec) tensor_spec = copy(self._tensor_spec)
spec.dist_spec = dist_spec tensor_spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec) ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, spec) return ColoTensor.from_torch_tensor(ret, tensor_spec)
@staticmethod @staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor': def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
......
...@@ -4,6 +4,7 @@ from numpy import prod ...@@ -4,6 +4,7 @@ from numpy import prod
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging import version
# TODO(jiaruifang) circle import, move the divide to colossalai.commons. # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
...@@ -56,6 +57,12 @@ class DistSpecManager: ...@@ -56,6 +57,12 @@ class DistSpecManager:
@staticmethod @staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor: def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
if version.parse(torch.__version__) < version.parse("1.11.0"):
# pytorch lower than 1.11 dose not support gather a cpu tensor.
# Therefore, we transfer tensor to GPU before gather.
saved_dev = tensor.device
tensor.data = tensor.data.cuda()
buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())] buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group) dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
for i in range(len(old_dist_spec.dims) - 1, -1, -1): for i in range(len(old_dist_spec.dims) - 1, -1, -1):
...@@ -66,6 +73,9 @@ class DistSpecManager: ...@@ -66,6 +73,9 @@ class DistSpecManager:
new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
buffer = new_buffer buffer = new_buffer
assert len(buffer) == 1 assert len(buffer) == 1
if version.parse(torch.__version__) < version.parse("1.11.0"):
buffer[0].data = buffer[0].data.to(saved_dev)
return buffer[0] return buffer[0]
@staticmethod @staticmethod
......
...@@ -24,28 +24,13 @@ class ParallelAction(object): ...@@ -24,28 +24,13 @@ class ParallelAction(object):
class TensorSpec(object): class TensorSpec(object):
""" """
It contains two aspects of information: The specification of the ColoTensor.
First, How are tensors distributed in Heterougenous memory space. Args:
Second, if the tensor is a model parameter, the Spec contains the dist_spec (_DistSpec): descriping the layout among processes.
parallel computation pattern of the Operator (Layer). parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
We have to consider the hybrid parallel mode. Defaults to None.
""" """
# a list of parallel actions.
# For example: On 8 GPUs, a hybrid parallel strategy is applied using
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
# parallel_action_list = [
# ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
# ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ]
# When the ColoTensor is initialized,
# we first splitting tensor according to ParallelAction of ZeRO,
# then splitting tensor according to ParallelAction of TP1D_Linear.
# During Linear computation
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None): def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
self.parallel_action = parallel_action self.parallel_action = parallel_action
self.dist_spec = dist_spec self.dist_spec = dist_spec
......
...@@ -3,6 +3,17 @@ import pytest ...@@ -3,6 +3,17 @@ import pytest
from colossalai.tensor import ColoTensor from colossalai.tensor import ColoTensor
from numpy import allclose from numpy import allclose
import colossalai
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec
from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor
from colossalai.context import ParallelMode
from functools import partial
def test_tensor_indexing(): def test_tensor_indexing():
torch_t = torch.randn(2, 3) torch_t = torch.randn(2, 3)
...@@ -25,8 +36,6 @@ def test_wrapped_tensor_func(): ...@@ -25,8 +36,6 @@ def test_wrapped_tensor_func():
# non-func attr # non-func attr
assert t.is_cuda == t_ref.is_cuda assert t.is_cuda == t_ref.is_cuda
# TODO I don't find out a tensor function which returns None.
# return 1 torch.Tensor # return 1 torch.Tensor
t_abs = t.abs() t_abs = t.abs()
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs()) assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
...@@ -47,3 +56,41 @@ def test_operand(): ...@@ -47,3 +56,41 @@ def test_operand():
t_res = t + t t_res = t + t
assert torch.allclose(t_ref_res, t_res) assert torch.allclose(t_ref_res, t_res)
#### Test Distributed init a Colotensor
def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5)
print(gpc.get_group(ParallelMode.DATA).size())
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
assert t.shape == torch.Size((4 * world_size, 5))
def _run_tensor_replicated_init(world_size):
t_ref = torch.randn(4 * world_size, 5)
t = ColoTensor.from_torch_tensor(t_ref.clone())
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
def run_tensor_init(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_tensor_shard_init(world_size)
_run_tensor_replicated_init(world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def _test_dist_init(world_size):
run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
# _test_dist_init(4)
test_new()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment