Unverified Commit 26b7aac0 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
parent b09adff7
import torch import torch
from colossalai.registry import OPHOOKS from colossalai.registry import OPHOOKS
from . import BaseOpHook from . import BaseOpHook
......
...@@ -5,9 +5,9 @@ from typing import List ...@@ -5,9 +5,9 @@ from typing import List
import torch import torch
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class TrainingPhase(Enum): class TrainingPhase(Enum):
......
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Union
import torch import torch
from typing import Union
from colossalai.gemini.gemini_context import GeminiMemoryManager from .gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor): def sizeof_tensor(tensor: torch.Tensor):
...@@ -19,7 +19,7 @@ class TensorState(Enum): ...@@ -19,7 +19,7 @@ class TensorState(Enum):
class StatefulTensor(object): class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states. """A Structure stores a Torch Tensor and labeled states.
Inspired from the paper: Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
......
import functools import functools
import torch
import types import types
from colossalai.utils.cuda import get_current_device from time import time
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List from typing import List
import torch
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from time import time from colossalai.utils.cuda import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class StatefulTensorMgr(object): class StatefulTensorMgr(object):
......
...@@ -5,11 +5,12 @@ from typing import List, Optional, Type ...@@ -5,11 +5,12 @@ from typing import List, Optional, Type
import torch import torch
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class TensorPlacementPolicy(ABC): class TensorPlacementPolicy(ABC):
......
from typing import Tuple, Union
import torch import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple from .stateful_tensor import StatefulTensor
def is_storage_empty(tensor: torch.Tensor) -> bool: def is_storage_empty(tensor: torch.Tensor) -> bool:
......
...@@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta ...@@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2 from colossalai.zero.legacy.sharded_param import ShardedParamV2
@dataclass @dataclass
......
...@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod ...@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
import torch.distributed as dist import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC): class BaseShardStrategy(ABC):
......
...@@ -2,17 +2,18 @@ from typing import List, Optional ...@@ -2,17 +2,18 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
from .tensor_shard_strategy import TensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy): class BucketTensorShardStrategy(TensorShardStrategy):
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
which will fully utilize network bandwidth. which will fully utilize network bandwidth.
It is especially useful when sub-module contains bias, It is especially useful when sub-module contains bias,
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small). since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small).
""" """
......
import torch
import torch.nn.functional as F
from typing import Tuple from typing import Tuple
import torch
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor.""" """Return the local shard of a full tensor."""
......
...@@ -2,11 +2,12 @@ from typing import List, Optional ...@@ -2,11 +2,12 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.legacy.shard_utils.commons import get_shard
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy): class TensorShardStrategy(BaseShardStrategy):
...@@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy): ...@@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy):
Args: Args:
t (ShardedTensor): a tensor to be sharded. t (ShardedTensor): a tensor to be sharded.
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
Defaults to None. Defaults to None.
""" """
if t.is_sharded: if t.is_sharded:
......
from .sharded_model_v2 import ShardedModelV2 from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModelV2'] __all__ = ['ShardedModelV2']
\ No newline at end of file
from typing import Any, Callable, List, Tuple from typing import Any, Callable, List, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Union
from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float: def get_gradient_predivide_factor(world_size: int) -> float:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment