Unverified Commit 4d90a7b5 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[refactor] zero directory (#724)

parent 20ab1f55
...@@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp ...@@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp from .naive_amp import convert_to_naive_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
"""A helper function to wrap training components with Torch AMP modules. """A helper function to wrap training components with Torch AMP modules.
......
from typing import List, Callable, Optional from .utils import register_ophooks_recursively, BaseOpHook
import torch __all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook
# Public names exported by a star-import of this package. This must be spelled
# ``__all__``: a plain ``all`` is ignored by the import system and shadows the
# ``all()`` builtin for the rest of the module.
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
# Route every plain tensor in ``outputs`` through an autograd Function that
# will invoke ``backward_function`` during the backward pass.
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    """Wrap each plain tensor found in ``outputs`` with ``functional``.

    Tuples are walked recursively and rebuilt as tuples. An object whose exact
    type is ``torch.Tensor`` is replaced by
    ``functional.apply(module, backward_function, tensor)``. Everything else
    (lists, dicts, tensor subclasses, scalars, ``None``) is returned untouched.
    """
    kind = type(outputs)
    if kind is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    if kind is not tuple:
        return outputs
    return tuple(
        _apply_to_tensors_only(module, functional, backward_function, item) for item in outputs)
class PreBackwardFunction(torch.autograd.Function):
    """Pass-through autograd op that runs a callback when gradients flow back.

    ``forward`` hands back a detached alias of its tensor argument and
    remembers the module and the callback; ``backward`` calls the callback
    with that module and forwards the incoming gradients unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        # Reset the per-module flag every time the module's tensor is wrapped.
        module.applied_pre_backward = False
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        return outputs.detach()

    @staticmethod
    def backward(ctx, *args):
        callback, target = ctx.pre_backward_function, ctx.module
        callback(target)
        # One entry per forward() argument: no gradient for ``module`` or the
        # callback, the received gradient(s) for the tensor.
        return (None, None, *args)
class PostBackwardFunction(torch.autograd.Function):
    """Pass-through autograd op that runs a callback when gradients flow back.

    ``forward`` hands back a detached alias of its tensor argument and
    remembers the module and the callback. Despite its parameter name, the
    callback stored here is whatever the caller supplies as the second
    argument.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        return output.detach()

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        callback, target = ctx.pre_backward_function, ctx.module
        callback(target)
        # No gradient for ``module`` or the callback; the received gradient(s)
        # are handed through unchanged.
        return (None, None, *args)
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.

    Args:
        module: root of the module tree to instrument.
        ophook_list: hooks to run around each instrumented module. ``None``
            (the default) means there is nothing to run, so nothing is
            registered. The list is captured by reference, so hooks appended
            to it later are also run.
        name: prefix passed down the recursion; each child's name is appended
            to it without a separator. It is not used for anything else here.
        filter_fn: optional predicate. A module for which it returns a truthy
            value gets no hooks of its own; its children are still visited.

    Only modules that directly own at least one parameter are instrumented.
    """
    assert isinstance(module, torch.nn.Module)

    # With no hook list there is nothing to call. Registering the closures
    # below anyway would make every later forward pass fail when they try to
    # iterate over ``None``.
    if ophook_list is None:
        return

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters of their own.
    if next(module.parameters(recurse=False), None) is None:
        return

    # Return from filtered module
    if filter_fn is not None and filter_fn(module):
        return

    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.pre_bwd_exec(submodule, inputs, output)

        # Registered as a forward hook: the returned value replaces the module
        # output, so the callback fires when gradients reach that output.
        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.post_bwd_exec(submodule, inputs)

        # Registered as a forward pre-hook: the returned tuple replaces the
        # module inputs, so the callback fires when gradients reach them.
        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
    """Interface for customized operations run before and after the
    execution of a PyTorch submodule.

    Every method below except ``__init__`` is abstract, so a concrete hook
    has to implement all five of them.
    """

    def __init__(self):
        """The base class keeps no state."""

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point just before ``module`` runs forward."""

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point just after ``module`` ran forward."""

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Callback for the point just before ``module`` runs backward."""

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Callback for the point just after ``module`` ran backward."""

    @abstractmethod
    def post_iter(self):
        """Callback for the end of an iteration (caller not visible here)."""
import torch
from typing import List, Callable, Optional
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
    """Interface for customized operations run before and after the
    execution of a PyTorch submodule.

    Every method below except ``__init__`` is abstract, so a concrete hook
    has to implement all five of them.
    """

    def __init__(self):
        """The base class keeps no state."""

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point just before ``module`` runs forward."""

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback for the point just after ``module`` ran forward."""

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Callback for the point just before ``module`` runs backward."""

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Callback for the point just after ``module`` ran backward."""

    @abstractmethod
    def post_iter(self):
        """Callback for the end of an iteration (caller not visible here)."""
# Route every plain tensor in ``outputs`` through an autograd Function that
# will invoke ``backward_function`` during the backward pass.
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    """Wrap each plain tensor found in ``outputs`` with ``functional``.

    Tuples are walked recursively and rebuilt as tuples. An object whose exact
    type is ``torch.Tensor`` is replaced by
    ``functional.apply(module, backward_function, tensor)``. Everything else
    (lists, dicts, tensor subclasses, scalars, ``None``) is returned untouched.
    """
    kind = type(outputs)
    if kind is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    if kind is not tuple:
        return outputs
    return tuple(
        _apply_to_tensors_only(module, functional, backward_function, item) for item in outputs)
class PreBackwardFunction(torch.autograd.Function):
    """Pass-through autograd op that runs a callback when gradients flow back.

    ``forward`` hands back a detached alias of its tensor argument and
    remembers the module and the callback; ``backward`` calls the callback
    with that module and forwards the incoming gradients unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        # Reset the per-module flag every time the module's tensor is wrapped.
        module.applied_pre_backward = False
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        return outputs.detach()

    @staticmethod
    def backward(ctx, *args):
        callback, target = ctx.pre_backward_function, ctx.module
        callback(target)
        # One entry per forward() argument: no gradient for ``module`` or the
        # callback, the received gradient(s) for the tensor.
        return (None, None, *args)
class PostBackwardFunction(torch.autograd.Function):
    """Pass-through autograd op that runs a callback when gradients flow back.

    ``forward`` hands back a detached alias of its tensor argument and
    remembers the module and the callback. Despite its parameter name, the
    callback stored here is whatever the caller supplies as the second
    argument.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        # Keep what backward() needs on the autograd context.
        ctx.pre_backward_function = pre_backward_function
        ctx.module = module
        return output.detach()

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        callback, target = ctx.pre_backward_function, ctx.module
        callback(target)
        # No gradient for ``module`` or the callback; the received gradient(s)
        # are handed through unchanged.
        return (None, None, *args)
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.

    Args:
        module: root of the module tree to instrument.
        ophook_list: hooks to run around each instrumented module. ``None``
            (the default) means there is nothing to run, so nothing is
            registered. The list is captured by reference, so hooks appended
            to it later are also run.
        name: prefix passed down the recursion; each child's name is appended
            to it without a separator. It is not used for anything else here.
        filter_fn: optional predicate. A module for which it returns a truthy
            value gets no hooks of its own; its children are still visited.

    Only modules that directly own at least one parameter are instrumented.
    """
    assert isinstance(module, torch.nn.Module)

    # With no hook list there is nothing to call. Registering the closures
    # below anyway would make every later forward pass fail when they try to
    # iterate over ``None``.
    if ophook_list is None:
        return

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters of their own.
    if next(module.parameters(recurse=False), None) is None:
        return

    # Return from filtered module
    if filter_fn is not None and filter_fn(module):
        return

    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.pre_bwd_exec(submodule, inputs, output)

        # Registered as a forward hook: the returned value replaces the module
        # output, so the callback fires when gradients reach that output.
        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            assert isinstance(submodule, torch.nn.Module)
            for hook in ophook_list:
                hook.post_bwd_exec(submodule, inputs)

        # Registered as a forward pre-hook: the returned tuple replaces the
        # module inputs, so the callback fires when gradients reach them.
        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
from .base_shard_strategy import BaseShardStrategy from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy
from .stateful_tensor_mgr import StatefulTensorMgr
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr'] __all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
...@@ -3,7 +3,7 @@ from typing import List, Optional ...@@ -3,7 +3,7 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
......
...@@ -8,9 +8,8 @@ import torch.nn as nn ...@@ -8,9 +8,8 @@ import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook from colossalai.zero.utils import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.engine.gradient_handler.utils import bucket_allreduce
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
...@@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \ ...@@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor) get_gradient_predivide_factor)
......
...@@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter ...@@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter
import torch.distributed as dist import torch.distributed as dist
def move_tensor(input_, device):
    """Move a tensor, or every tensor in a list/tuple, to CPU or GPU in place.

    Only the ``.data`` attribute of each tensor is rebound; nothing is
    returned.

    Args:
        input_: a ``torch.Tensor``, or a list/tuple of tensors.
        device: ``'cpu'`` or ``'gpu'``.

    Raises:
        AssertionError: if ``device`` is not ``'cpu'`` or ``'gpu'``.
        TypeError: if ``input_`` is neither a tensor nor a list/tuple.
    """
    assert device in ['cpu', 'gpu']

    if isinstance(input_, (list, tuple)):
        tensors = input_
    elif torch.is_tensor(input_):
        # Treat a lone tensor as a one-element batch so both cases share the
        # same move logic. The previous single-tensor branch referenced the
        # list branch's loop variable and failed for device == 'gpu'.
        tensors = (input_,)
    else:
        raise TypeError(f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} ")

    for tensor in tensors:
        tensor.data = tensor.data.cpu() if device == 'cpu' else tensor.data.cuda()
def flatten(input_): def flatten(input_):
return _flatten_dense_tensors(input_) return _flatten_dense_tensors(input_)
...@@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions): ...@@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
partition_to_go = tensor_idx % num_partitions partition_to_go = tensor_idx % num_partitions
if partition_to_go not in partitions: if partition_to_go not in partitions:
partitions[partition_to_go] = [] partitions[partition_to_go] = []
partitions[partition_to_go].append(dict(tensor=tensor, partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx))
index=tensor_idx))
partitions_count = len(partitions) partitions_count = len(partitions)
new_tensor_list = [] new_tensor_list = []
...@@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size): ...@@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
padding = calculate_padding(num_elements, unit_size=unit_size) padding = calculate_padding(num_elements, unit_size=unit_size)
if padding > 0: if padding > 0:
pad_tensor = torch.zeros(padding, pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor] padded_tensor_list = tensor_list + [pad_tensor]
else: else:
padded_tensor_list = tensor_list padded_tensor_list = tensor_list
...@@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size): ...@@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
def is_nccl_aligned(tensor): def is_nccl_aligned(tensor):
return tensor.data_ptr() % 4 == 0 return tensor.data_ptr() % 4 == 0
def get_grad_accumulate_object(tensor): def get_grad_accumulate_object(tensor):
""" """
Return the AccumulateGrad of the input tensor Return the AccumulateGrad of the input tensor
...@@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor): ...@@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor):
def split_half_float_double(tensor_list): def split_half_float_double(tensor_list):
dtypes = [ dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
"torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
]
buckets = [] buckets = []
for i, dtype in enumerate(dtypes): for i, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype] bucket = [t for t in tensor_list if t.type() == dtype]
...@@ -120,10 +99,7 @@ def split_half_float_double(tensor_list): ...@@ -120,10 +99,7 @@ def split_half_float_double(tensor_list):
return buckets return buckets
def reduce_tensor(tensor, def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
dtype,
dst_rank=None,
parallel_mode=ParallelMode.DATA):
""" """
Reduce the tensor in the data parallel process group Reduce the tensor in the data parallel process group
...@@ -165,6 +141,7 @@ def reduce_tensor(tensor, ...@@ -165,6 +141,7 @@ def reduce_tensor(tensor,
tensor.copy_(tensor_to_reduce) tensor.copy_(tensor_to_reduce)
return tensor return tensor
def has_inf_or_nan(tensor): def has_inf_or_nan(tensor):
try: try:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
...@@ -181,8 +158,7 @@ def has_inf_or_nan(tensor): ...@@ -181,8 +158,7 @@ def has_inf_or_nan(tensor):
raise raise
return True return True
else: else:
if tensor_sum == float('inf') or tensor_sum == -float( if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
'inf') or tensor_sum != tensor_sum:
return True return True
return False return False
...@@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list): ...@@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list):
return math.sqrt(total_norm) return math.sqrt(total_norm)
def compute_norm(gradients, def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
params,
dp_group,
mp_group,
norm_type=2):
"""Clips gradient norm of an iterable of parameters. """Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that added functionality to handle model parallel parameters. Note that
...@@ -229,14 +201,11 @@ def compute_norm(gradients, ...@@ -229,14 +201,11 @@ def compute_norm(gradients,
if norm_type == inf: if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients) total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
op=torch.distributed.ReduceOp.MAX,
group=dp_group)
# Take max across all GPUs. # Take max across all GPUs.
if mp_group is not None: if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item() total_norm = total_norm_cuda[0].item()
else: else:
total_norm = 0.0 total_norm = 0.0
...@@ -248,21 +217,17 @@ def compute_norm(gradients, ...@@ -248,21 +217,17 @@ def compute_norm(gradients,
if is_model_parallel_parameter(p) or mp_rank == 0: if is_model_parallel_parameter(p) or mp_rank == 0:
param_norm = g.data.double().norm(2) param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2 total_norm += param_norm.item()**2
# Sum across all model parallel GPUs. # Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
op=torch.distributed.ReduceOp.SUM,
group=dp_group)
if mp_group is not None: if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type) total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float( if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1 total_norm = -1
return total_norm return total_norm
......
...@@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger ...@@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone, from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage) colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan from colossalai.zero.sharded_optim._utils import has_inf_or_nan
......
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
colo_model_data_move_to_cpu, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
__all__ = ['ShardedTensor', 'ShardedParamV2'] __all__ = [
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
]
import torch import torch
from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple from typing import Optional, Tuple
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState from .tensorful_state import StatefulTensor, TensorState
from typing import List from typing import List
......
from .stateful_tensor_mgr import StatefulTensorMgr
from .zero_hook import ZeroHook
__all__ = ['StatefulTensorMgr', 'ZeroHook']
\ No newline at end of file
...@@ -4,7 +4,7 @@ import types ...@@ -4,7 +4,7 @@ import types
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Dict, List from typing import Dict, List
......
...@@ -3,15 +3,16 @@ from typing import Optional ...@@ -3,15 +3,16 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.registry import OPHOOKS from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from ._base_ophook import BaseOpHook
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline from colossalai.engine.ophooks import BaseOpHook
@OPHOOKS.register_module @OPHOOKS.register_module
......
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param import ShardedTensor
......
import pytest import pytest
import colossalai
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
import colossalai
import torch import torch
......
...@@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio) ...@@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext( with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()), shard_strategy=shard_strategy,
shard_strategy=shard_strategy, shard_param=True):
shard_param=True):
zero_model = model_builder(checkpoint=True) zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2( zero_model = ShardedModelV2(
zero_model, zero_model,
......
...@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device ...@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.shard_utils import StatefulTensorMgr from colossalai.zero.utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port from colossalai.utils import free_port
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment