Commit 8cf7ff08 authored by ver217
Browse files

polish code

parent e99af94a
import imp
from functools import partial from functools import partial
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam
from colossalai.utils import checkpoint from colossalai.utils import checkpoint
from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
...@@ -20,8 +18,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, ...@@ -20,8 +18,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
use_memory_tracer=False, use_memory_tracer=False,
shard_strategy=TensorShardStrategy) shard_strategy=TensorShardStrategy)
_ZERO_OPTIMIZER_CONFIG = dict( _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
cpu_offload=False,
initial_scale=2**5, initial_scale=2**5,
min_scale=1, min_scale=1,
growth_factor=2, growth_factor=2,
...@@ -35,7 +32,7 @@ ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), ...@@ -35,7 +32,7 @@ ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict( zero=dict(
model_config=_ZERO_MODEL_CONFIG, model_config=_ZERO_MODEL_CONFIG,
optimizer_config=_ZERO_OPTIMIZER_CONFIG, optimizer_config=_ZERO_OPTIMIZER_CONFIG,
), ),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
CONFIG = dict(fp16=dict(mode=None,), CONFIG = dict(fp16=dict(mode=None,),
......
...@@ -10,8 +10,7 @@ import torch.multiprocessing as mp ...@@ -10,8 +10,7 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy from colossalai.zero.sharded_model.utils import col_model_deepcopy
...@@ -22,10 +21,10 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd ...@@ -22,10 +21,10 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@parameterize("enable_autocast", [True]) @parameterize("enable_autocast", [True])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy): def run_model_test(enable_autocast, shard_strategy_class):
test_models = ['repeated_computed_layers', 'resnet18', 'bert'] test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy() shard_strategy = shard_strategy_class()
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func() model_builder, train_dataloader, _, _, criterion = get_components_func()
......
...@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam ...@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.zero.sharded_optim import ShardedOptimizerV2
...@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): ...@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@parameterize("cpu_offload", [True, False]) @parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False]) @parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam): def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
test_models = ['repeated_computed_layers', 'resnet18', 'bert'] test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy() shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False: if use_cpuadam and cpu_offload is False:
return return
......
...@@ -8,20 +8,21 @@ import colossalai ...@@ -8,20 +8,21 @@ import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize
from common import CONFIG from common import CONFIG
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy): def run_zero_state_dict(shard_strategy_class):
test_models = ['repeated_computed_layers', 'resnet18'] test_models = ['repeated_computed_layers', 'resnet18']
shard_strategy = shard_strategy() shard_strategy = shard_strategy_class()
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment