Unverified Commit 49832b23 authored by Jiarui Fang, committed by GitHub

[refactor] add nn.parallel module (#1068)

parent 6754f1b7
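This commit gathers the ColoTensor-aware layer descriptors and the DDP wrappers into a new `colossalai.nn.parallel` package and stops re-exporting them from the top of `colossalai.nn` and `colossalai.utils`. A before/after sketch of the import moves, assembled from the hunks below (the grouping into one snippet is illustrative):

```python
# Before this commit (old paths, removed below):
# from colossalai.nn import ColoLinear, ColoEmbedding, register_colo_module, init_colo_module
# from colossalai.utils import ColoInitContext

# After this commit (new paths, added below):
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.nn.parallel.layers import (ColoEmbedding, ColoLinear,
                                           init_colo_module, register_colo_module)
from colossalai.utils.model.colo_init_context import ColoInitContext
```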
@@ -5,6 +5,3 @@ from .metric import *
from .model import *
from .optimizer import *
from ._ops import *
-from .modules import ColoLinear, ColoEmbedding
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .colo_module import ColoModule
-from .linear import ColoLinear
-from .embedding import ColoEmbedding
\ No newline at end of file
+from .data_parallel import ColoDDP, ColoDDPV2
+__all__ = ['ColoDDP', 'ColoDDPV2']
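For orientation, a minimal sketch of how the re-exported `ColoDDP` is used; only the import path comes from this hunk, and the `backward()` call plus the assumption that process groups were already initialized via `colossalai.launch` follow the tests further down:

```python
import torch
import torch.nn as nn
from colossalai.nn.parallel import ColoDDP  # re-exported by the new package __init__

# Assumes colossalai.launch(...) already initialized the default process group.
model = ColoDDP(nn.Linear(16, 4).cuda())
logits = model(torch.randn(8, 16, device='cuda'))
loss = logits.sum()
model.backward(loss)  # ColoDDP drives the backward pass so it can manage grad reduction
```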
@@ -7,8 +7,6 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.param_op_hook import use_param_op_hooks
__all__ = ['ColoDDP', 'ColoDDPV2']
def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
......
+from .colo_module import ColoModule
+from .linear import ColoLinear
+from .embedding import ColoEmbedding
+from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
+__all__ = [
+    'ColoModule',
+    'register_colo_module',
+    'is_colo_module',
+    'get_colo_module',
+    'init_colo_module',
+    'check_colo_module',
+    'ColoLinear',
+    'ColoEmbedding',
+]
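Only the names in the `__all__` above are guaranteed by this diff; the call pattern below (register one descriptor instance per torch layer type) is an assumption based on how `ColoInitContext` wires these up:

```python
import torch.nn as nn
from colossalai.nn.parallel.layers import (ColoEmbedding, ColoLinear,
                                           get_colo_module, is_colo_module,
                                           register_colo_module)

# Tell the framework how each torch layer type can be sharded.
register_colo_module(nn.Linear, ColoLinear())
register_colo_module(nn.Embedding, ColoEmbedding())

layer = nn.Linear(8, 8)
assert is_colo_module(layer)         # registry hit for nn.Linear (assumed semantics)
descriptor = get_colo_module(layer)  # the ColoLinear descriptor registered above
```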
from typing import Dict
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
-from .modules import ColoModule
+from . import ColoModule
import torch
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
......
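`_COLOSSAL_MODULES` keys the registry by layer type. A self-contained sketch of the lookup such a `Dict[type, ColoModule]` implies; the stub class and exact-type semantics are mine, not the library's:

```python
from typing import Dict, Optional

class ColoModuleStub:
    """Stand-in for colossalai's ColoModule; illustration only."""

_registry: Dict[type, ColoModuleStub] = {}

def register_colo_module(module_type: type, colo_module: ColoModuleStub) -> None:
    # One descriptor per concrete layer type.
    _registry[module_type] = colo_module

def get_colo_module(module: object) -> Optional[ColoModuleStub]:
    # Exact-type lookup: a subclass of a registered type needs its own entry.
    return _registry.get(type(module))

def is_colo_module(module: object) -> bool:
    return type(module) in _registry
```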
@@ -11,8 +11,6 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
-from .model.utils import InsertPostInitMethodToModuleSubClasses
-from .model.colo_init_context import ColoInitContext
__all__ = [
'checkpoint',
@@ -52,6 +50,4 @@ __all__ = [
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
-'InsertPostInitMethodToModuleSubClasses',
-'ColoInitContext',
]
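`ColoInitContext` is no longer re-exported from `colossalai.utils`; callers now import it from its defining module, as every test below does. A hedged usage sketch (the `device` keyword follows the test pattern in this suite and should be treated as an assumption):

```python
import torch.nn as nn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters created inside the context are built as ColoTensor-backed
# parameters on the chosen device.
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))
```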
@@ -2,7 +2,7 @@ from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter
-from colossalai.nn import register_colo_module, init_colo_module, \
+from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
from torch import nn
......
import torch
import functools
import inspect
from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor
......
import contextlib
import functools
from typing import Optional
from contextlib import AbstractContextManager
import torch
import torch.nn as nn
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
-from contextlib import AbstractContextManager
-from colossalai.utils import InsertPostInitMethodToModuleSubClasses
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
class ZeroContextConfig(object):
......
@@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel.data_parallel import ColoDDPV2
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc
......
import pytest
-from colossalai.utils import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext
from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils.cuda import get_current_device
......
@@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDP
+from colossalai.nn.parallel.data_parallel import ColoDDP
def init_1d_row_spec(model):
......
-from colossalai.utils import free_port, ColoInitContext, get_current_device
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern, ParallelAction
from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
-from colossalai.nn import init_colo_module
-from colossalai.nn.parallel import ColoDDP
+from colossalai.nn.parallel.layers import init_colo_module
+from colossalai.nn.parallel.data_parallel import ColoDDP
import colossalai
import torch
......
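`init_colo_module` now lives under `colossalai.nn.parallel.layers`; it walks a module and applies a `ParallelAction` through the registered `ColoModule` descriptors. A sketch of the call pattern these imports suggest; the `ParallelAction` constructor arguments and `mode='row'` are assumptions, not taken from this diff:

```python
import torch.nn as nn
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import ComputePattern, ParallelAction
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device='cuda'):
    model = nn.Linear(16, 16)

# 1D tensor parallelism, sharding weights row-wise (arguments assumed).
parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, mode='row')
```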
@@ -5,11 +5,11 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
-from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
......
@@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
-from colossalai.nn import init_colo_module, check_colo_module
+from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed
import colossalai
......
@@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize
......
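`ColoDDPV2` pairs the DDP wrapper with the `ChunkManager` imported above, which groups parameters into fixed-size chunks for ZeRO-style memory management. A hedged construction sketch; the chunk size and the two-argument constructor are assumptions drawn from how these tests pair the two classes:

```python
import torch.nn as nn
from colossalai.nn.parallel import ColoDDPV2
from colossalai.tensor import ChunkManager
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device='cuda'):
    model = nn.Linear(32, 32)

# Chunk size is in number of elements (value illustrative).
chunk_manager = ChunkManager(chunk_size=64 * 1024)
model = ColoDDPV2(model, chunk_manager)
```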