Unverified commit 26b7aac0 authored by ver217, committed by GitHub

[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
parent b09adff7
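The hunks below all apply the same renaming. As a reading aid, a minimal sketch of the new import layout this commit introduces, with the old paths noted in comments; every path is copied from the hunks below, and intra-package imports additionally become relative (e.g. `from .gemini_context import GeminiMemoryManager`).

# New layout introduced by this commit (paths copied from the hunks below):
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor    # was colossalai.gemini.memory_tracer
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor            # was colossalai.gemini.stateful_tensor
from colossalai.zero.legacy.shard_utils import BaseShardStrategy                    # was colossalai.zero.shard_utils
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2    # was colossalai.zero.sharded_model.sharded_model_v2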
@@ -5,9 +5,9 @@ from typing import List
 import torch
-from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
-from colossalai.gemini.tensor_utils import alloc_storage, free_storage
 from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
+from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
 class TrainingPhase(Enum):
......
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 import torch
-from typing import Union
-from colossalai.gemini.gemini_context import GeminiMemoryManager
+from .gemini_context import GeminiMemoryManager
 def sizeof_tensor(tensor: torch.Tensor):
@@ -19,7 +19,7 @@ class TensorState(Enum):
 class StatefulTensor(object):
-    """A Structure stores a Torch Tensor and labeled states.
+    """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
......
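The hunk above only shows the signature of `sizeof_tensor`; a minimal sketch of what such a helper typically computes, assuming it reports the tensor payload in bytes (the real body is outside the diff context):

import torch


def sizeof_tensor(tensor: torch.Tensor) -> int:
    # Assumed behaviour: number of elements times bytes per element.
    return tensor.numel() * tensor.element_size()


# e.g. a 1024x1024 fp16 tensor occupies 2 MiB
print(sizeof_tensor(torch.empty(1024, 1024, dtype=torch.float16)))  # 2097152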
 import functools
-import torch
 import types
-from colossalai.utils.cuda import get_current_device
-from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
-from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
+from time import time
 from typing import List
+import torch
 from colossalai.logging import get_dist_logger
-from time import time
+from colossalai.utils.cuda import get_current_device
+from .stateful_tensor import StatefulTensor, TensorState
+from .tensor_placement_policy import TensorPlacementPolicy
+from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 class StatefulTensorMgr(object):
......
@@ -5,11 +5,12 @@ from typing import List, Optional, Type
 import torch
-from colossalai.gemini.memory_tracer import MemStatsCollector
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.gemini.memory_tracer import MemStatsCollector
+from .stateful_tensor import StatefulTensor
+from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 class TensorPlacementPolicy(ABC):
......
+from typing import Tuple, Union
 import torch
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from typing import Union, Tuple
+from .stateful_tensor import StatefulTensor
 def is_storage_empty(tensor: torch.Tensor) -> bool:
......
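The bodies of `is_storage_empty`, `alloc_storage` and `free_storage` fall outside the diff context; a hedged sketch of how such helpers are usually built on top of PyTorch storages (an assumption, not the file's actual code):

import torch


def is_storage_empty(tensor: torch.Tensor) -> bool:
    # A tensor whose underlying storage was resized to zero holds no device memory.
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    # Drop the payload but keep the tensor object (shape/dtype metadata) alive.
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    # Re-allocate the payload so the tensor can be written into again.
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())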
@@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.zero.legacy.shard_utils import BaseShardStrategy
+from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.legacy.sharded_param import ShardedParamV2
 @dataclass
......
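These are the imports of the legacy ZeRO init context; a hedged usage sketch of how that context manager is typically driven after the move (the `ZeroInitContext` name, its module path and its argument names are assumptions based on the legacy API and are not shown in this hunk):

import torch
import torch.nn as nn

from colossalai.zero.legacy.init_ctx import ZeroInitContext          # assumed new location
from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # path shown in the hunks below

# Parameters built inside the context are cast to fp16 and sharded as they are created,
# so the full fp32 model never has to fit on a single device.
with ZeroInitContext(target_device=torch.device('cuda:0'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))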
@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 from typing import List, Optional
 import torch.distributed as dist
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
 class BaseShardStrategy(ABC):
......
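For orientation, a minimal sketch of the interface a shard strategy base class of this shape exposes; the `shard`/`gather` method names are inferred from the subclasses in the following hunks and should be treated as an assumption:

from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist

from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor


class BaseShardStrategy(ABC):

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Keep only the local shard of each tensor on this rank."""

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Restore the full payload of each tensor on every rank."""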
@@ -2,17 +2,18 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
-from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from torch._utils import _flatten_dense_tensors as flatten
+from colossalai.utils import get_current_device
+from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
 from .tensor_shard_strategy import TensorShardStrategy
 class BucketTensorShardStrategy(TensorShardStrategy):
-    """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
-    which will fully utilize network bandwidth.
-    It is especially useful when sub-module contains bias,
+    """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
+    which will fully utilize network bandwidth.
+    It is especially useful when sub-module contains bias,
     since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small).
     """
......
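A hedged sketch of the bucketed gather the docstring above describes: instead of one collective per tensor, each rank flattens the shards it owns into a single buffer and runs one `all_gather` (the helper name and buffer handling are illustrative, not the class's actual code):

from typing import List

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten


def gather_bucket(local_shards: List[torch.Tensor], group: dist.ProcessGroup = None) -> torch.Tensor:
    # One flat buffer per rank instead of many small tensors ...
    buffer = flatten(local_shards)
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(buffer) for _ in range(world_size)]
    # ... so a single all_gather moves every sub-module tensor (including tiny biases)
    # at once and uses the network bandwidth much better.
    dist.all_gather(gathered, buffer, group=group)
    return torch.cat(gathered)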
-import torch
-import torch.nn.functional as F
 from typing import Tuple
+import torch
+import torch.nn.functional as F
 def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
     """Return the local shard of a full tensor."""
......
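The `torch.nn.functional` import above hints at padding; a hedged sketch of what a `get_shard` helper with this signature usually does (flatten, chunk evenly across ranks, pad the last chunk, and report the padding), offered as an assumption since the body is outside the hunk:

from typing import Tuple

import torch
import torch.nn.functional as F


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor plus the number of padded elements."""
    chunks = torch.flatten(tensor).chunk(world_size)
    # Ranks beyond the number of chunks own an empty shard.
    shard = chunks[rank].clone() if rank < len(chunks) else tensor.new_empty(0)
    # Pad to the size of the first chunk so every rank holds an equally sized buffer.
    num_to_pad = chunks[0].numel() - shard.numel()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad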
@@ -2,11 +2,12 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
-from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.shard_utils.commons import get_shard
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.zero.legacy.shard_utils import BaseShardStrategy
+from colossalai.zero.legacy.shard_utils.commons import get_shard
+from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
 class TensorShardStrategy(BaseShardStrategy):
@@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy):
         Args:
             t (ShardedTensor): a tensor to be sharded.
-            process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
+            process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
                 Defaults to None.
         """
         if t.is_sharded:
......
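Putting the previous two hunks together, a hedged sketch of the per-tensor sharding step that the `shard` docstring describes; only `get_shard`, its import path and the early exit on `t.is_sharded` come from the hunks, while the payload-handling calls on the sharded tensor are hypothetical placeholders:

import torch.distributed as dist

from colossalai.zero.legacy.shard_utils.commons import get_shard  # path from the hunk above


def shard_one_tensor(t, process_group=None):
    # Early exit when the tensor is already sharded (shown in the hunk above).
    if t.is_sharded:
        return
    rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)
    # Keep only this rank's slice of the payload and remember how much padding was added.
    sharded_payload, num_pad = get_shard(t.payload, rank, world_size)
    t.payload = sharded_payload   # hypothetical payload setter on ShardedTensor
    t.is_sharded = True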
 from .sharded_model_v2 import ShardedModelV2
-__all__ = ['ShardedModelV2']
\ No newline at end of file
+__all__ = ['ShardedModelV2']
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Tuple, Union
 import torch
 import torch.nn.functional as F
-from typing import Union
-from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
 def get_gradient_predivide_factor(world_size: int) -> float:
......
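A hedged sketch of the usual heuristic behind a helper named `get_gradient_predivide_factor` in FSDP-style sharded models: gradients are divided by a factor before the all-reduce and by `world_size / factor` afterwards to reduce fp16 over/underflow. The actual body is outside the hunk, so this is an assumption:

def get_gradient_predivide_factor(world_size: int) -> float:
    # Double the factor while it still divides world_size and stays below world_size / factor;
    # gradients get divided by `factor` before the all-reduce and by `world_size / factor` after.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)


# e.g. world_size=64 -> 8.0: divide grads by 8 before the all-reduce and by 8 after.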