"examples/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7486ed7d3a21ad35c4f465583426b25af6b33c04"
Unverified Commit bf5066fb authored by HELSON, committed by GitHub

[refactor] refactor ColoTensor's unit tests (#1340)

parent f92c100d
from ._util import *
\ No newline at end of file
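The one-line module above appears to be the `__init__.py` of the new shared-helpers package: `from ._util import *` re-exports everything `_util.py` defines, giving the whole test suite a single stable import path. A minimal sketch of the layout implied by the import paths in the hunks below (the file names beyond `common_utils` are inferred from this commit's imports, not confirmed by the excerpt):

tests/
    test_tensor/
        common_utils/
            __init__.py    # from ._util import *
            _util.py       # tensor_equal, tensor_shard_equal, set_seed, ...

With that package in place, a test imports helpers absolutely, e.g. `from tests.test_tensor.common_utils import tensor_equal`, instead of relying on a sibling `_utils.py` happening to be on `sys.path`.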
 import pytest
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
......
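Every hunk in this commit swaps the same small family of comparison helpers onto the new import path. Their real implementations live in `common_utils` and are not part of this diff; the sketch below is an illustrative assumption of what such helpers typically look like, not the repository's code.

import random

import numpy as np
import torch


def set_seed(seed):
    # Assumed behavior: make a test run deterministic across reruns/processes.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def tensor_equal(t_a, t_b):
    # Assumed behavior: elementwise comparison with a loose float tolerance,
    # since parallel execution can reorder floating-point reductions.
    return torch.allclose(t_a, t_b, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor, shard, rank, world_size):
    # Assumed behavior: `shard` should equal this rank's slice of `tensor`
    # along whichever dimension was divided by `world_size`.
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
    for dim in range(tensor.dim()):
        if tensor.size(dim) == shard.size(dim) * world_size:
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
    raise ValueError('shard does not match any 1D split of tensor')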
 import pytest
 from functools import partial
-from _utils import tensor_shard_equal, set_seed
 import torch
 import torch.multiprocessing as mp
...@@ -15,7 +13,8 @@ from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
-from _utils import split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \
+    split_param_row_tp1d, split_param_col_tp1d

 def run_1d_hybrid_tp(model_name):
...@@ -264,7 +263,6 @@ def run_1d_row_tp(model_name: str):
 def _run_pretrain_load():
-    from _utils import check_equal
     from transformers import BertForMaskedLM
     set_seed(1)
     model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
......
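The consolidated import above also pulls in `split_param_row_tp1d` and `split_param_col_tp1d`, which the model tests use to shard a parameter across a 1D tensor-parallel process group. As a hedged sketch of the idea, assuming `ColoTensor.set_tensor_spec` and `ProcessGroup.tp_world_size()` behave as their names suggest (this excerpt does not show their definitions):

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec


def split_param_single_dim_tp1d(dim, param, pg: ProcessGroup):
    # Shard `param` along `dim` into pg.tp_world_size() pieces and mark it
    # for 1D tensor-parallel compute.
    param.set_tensor_spec(ShardSpec([dim], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))


def split_param_row_tp1d(param, pg):
    split_param_single_dim_tp1d(0, param, pg)    # row-wise: split dim 0


def split_param_col_tp1d(param, pg):
    split_param_single_dim_tp1d(-1, param, pg)   # column-wise: split last dim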
...@@ -7,7 +7,7 @@ import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
 import colossalai
 from colossalai.utils.cuda import get_current_device
......
...@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensorSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
-from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d

 class Conv1D(nn.Module):
......
...@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d

 def run_with_spec(spec_init_func):
......
...@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d

 def run_with_spec(spec_init_func, pg: ProcessGroup):
......
...@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d

 def run_with_spec(spec_init_func, split_bias):
......
 from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
 import torch
 import pytest
-from _utils import tensor_equal
+from common_utils import tensor_equal
 import colossalai
 from colossalai.utils import free_port
......
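Worth noting: unlike the other files, this one imports the helper as `from common_utils import tensor_equal`, a path that only resolves when `tests/test_tensor` itself is on `sys.path` (for example, when pytest is launched from inside that directory); the fully qualified `tests.test_tensor.common_utils` form used elsewhere carries no such constraint.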
...@@ -8,7 +8,7 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.gemini import ChunkManager
 from functools import partial
-from _utils import tensor_equal, set_seed, tensor_shard_equal
+from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ZeroDDP
......
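The imports in this last file (ColoInitContext, ChunkManager, ZeroDDP, torch DDP, plus the comparison helpers) outline the usual shape of these tests: build one model under ColoInitContext so its parameters become ColoTensors, build an identical plain copy as a baseline, then assert the sharded and unsharded parameters agree. A rough sketch under those assumptions; `model_builder` is a hypothetical stand-in for the callables provided by `non_distributed_component_funcs`, and `tensor_shard_equal` is assumed to take (full tensor, shard, rank, world_size) as sketched earlier:

import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal


def build_model_pair(model_builder):
    set_seed(42)                           # identical init for both copies
    with ColoInitContext(device=get_current_device()):
        colo_model = model_builder()       # parameters created as ColoTensors
    set_seed(42)
    torch_model = model_builder().to(get_current_device())
    return colo_model, torch_model


def check_params(colo_model, torch_model, rank, world_size):
    # Each (possibly sharded) ColoTensor parameter must match this rank's
    # slice of the corresponding full parameter in the baseline model.
    for p_colo, p_torch in zip(colo_model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(p_torch, p_colo, rank, world_size)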