"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "725af3eeeb16f7a348578e19105bed4f4096e0ca"
Unverified commit 10ef8afd, authored by Jiarui Fang, committed by GitHub

[gemini] init genimi individual directory (#754)

parent dcca614e
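
In short, this commit creates a standalone gemini package and re-points every import of the stateful tensor machinery from colossalai.zero.utils to colossalai.gemini. A minimal before/after sketch of the change as it affects downstream code, using only the import paths visible in the hunks below:

# before this commit
from colossalai.zero.utils import StatefulTensorMgr
from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory

# after this commit
from colossalai.gemini import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory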
+from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
+__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory']
\ No newline at end of file
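
The new package __init__ re-exports both public names, so callers can import them from the package root instead of the submodules. A hedged usage sketch, assuming only the names listed in __all__ above:

# root-level imports made possible by the new __init__
from colossalai.gemini import StatefulTensorMgr, TensorPlacementPolicyFactory

# equivalent submodule imports, as used elsewhere in this diff
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory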
@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
 from colossalai.logging import get_dist_logger
@@ -22,8 +22,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -21,7 +21,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
 class OptimState(Enum):
-from .stateful_tensor_mgr import StatefulTensorMgr
-from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook
-__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
\ No newline at end of file
+__all__ = ['ZeroHook']
\ No newline at end of file
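
With StatefulTensorMgr and TensorPlacementPolicyFactory gone from colossalai.zero.utils, only ZeroHook remains exported there. Third-party code that must run on both sides of this commit could guard the import; a hedged compatibility sketch (this try/except shim is illustrative, not part of the diff):

try:
    # new location, introduced by this commit
    from colossalai.gemini import StatefulTensorMgr
except ImportError:
    # old location, for checkouts before this commit
    from colossalai.zero.utils import StatefulTensorMgr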
@@ -9,8 +9,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.engine.ophooks import BaseOpHook
@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_set_process_memory_fraction
-from colossalai.zero.utils import StatefulTensorMgr
+from colossalai.gemini import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.utils import free_port
@@ -14,7 +14,9 @@ from colossalai.testing import rerun_on_exception
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
 class Net(torch.nn.Module):
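
Note the asymmetry in the updated test imports: StatefulTensorMgr comes from the package root, while AutoTensorPlacementPolicy is imported from the tensor_placement_policy submodule. That matches the new __init__ above, which re-exports the manager and the factory but not the concrete policy classes; a short sketch restating only what this diff shows:

# re-exported at the package root (per the new __init__)
from colossalai.gemini import StatefulTensorMgr, TensorPlacementPolicyFactory
# concrete policy classes stay in the submodule
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy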