Commit e532679c authored by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from .activation_function import *
from .arithmetic import *
from .convolution import *
from .embedding import *
from .normalization import *
from .python_ops import *
from .torch_ops import *
from .convolution import *
\ No newline at end of file
import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu)
......
import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.matmul)
......@@ -57,16 +57,36 @@ def torch_bmm(input, mat2, *, out=None):
return torch.empty(batch_size, n, p, device="meta")
@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, bias=None, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
output_shape = list(input.shape)
output_feature = list(mat2.shape)[0]
output_shape[-1] = output_feature
return torch.empty(*output_shape, device="meta")
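The patched linear above only needs shape information, so it can run entirely on meta tensors. A minimal sketch of the same shape inference, using illustrative tensors that are not part of this diff:
import torch
# hypothetical inputs on the meta device: activations and a weight matrix
x = torch.empty(8, 128, 512, device="meta")    # (batch, seq, in_features)
w = torch.empty(1024, 512, device="meta")      # (out_features, in_features)
# mirror torch_linear: copy the input shape and replace the last dim with out_features
output_shape = list(x.shape)
output_shape[-1] = w.shape[0]
out = torch.empty(*output_shape, device="meta")
print(out.shape)    # torch.Size([8, 128, 1024]), no real memory allocated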
@meta_patched_function.register(torch.addbmm)
@meta_patched_function.register(torch.Tensor.addbmm)
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
batch_size, n, m = mat1.shape
_, n, _ = mat1.shape
_, _, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.addmm)
@meta_patched_function.register(torch.Tensor.addmm)
def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
n, _ = mat1.shape
_, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
assert out is None, 'saving to out is not supported yet'
......
import torch
import collections
from itertools import repeat
from ..registry import meta_patched_function
import math
from itertools import repeat
import torch
from ...registry import meta_patched_function
def _ntuple(n, name="parse"):
......
import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.embedding)
......
import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.layer_norm)
......
import operator
import torch
from ..registry import meta_patched_function
from colossalai.fx.proxy import ColoProxy
from ...registry import meta_patched_function
@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
......
import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.arange)
......
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.ReLU)
......
import math
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Conv1d)
......
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Embedding)
......
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
......
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.LayerNorm)
......
import math
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.AvgPool1d)
......
import torch
from ..registry import meta_patched_module
from typing import Optional
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
......
......@@ -23,3 +23,6 @@ class PatchRegistry:
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
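The PatchRegistry implementation itself is collapsed in this hunk; only the new registry instances are shown. A rough sketch of the register/has/get interface the tracer relies on, written as an assumption rather than the actual class body:
class PatchRegistry:
    """Assumed minimal interface: maps a patched target to its replacement."""

    def __init__(self, name: str):
        self.name = name
        self.store = {}

    def register(self, source):
        # used as a decorator, e.g. @meta_patched_function.register(torch.matmul)
        def wrapper(func):
            self.store[source] = func
            return func
        return wrapper

    def has(self, source) -> bool:
        return source in self.store

    def get(self, source):
        return self.store[source]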
......@@ -5,22 +5,29 @@ tracer.py:
The implementation is partly inspired by HuggingFace's fx tracer
"""
import enum
import inspect
import functools
import inspect
import operator
from contextlib import contextmanager
from colossalai.fx.tracer.meta_patch import meta_patched_module
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Tracer, Node
from torch.fx.graph import Graph
from torch.fx.proxy import Proxy, ParameterProxy
from torch.fx import Node, Tracer
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
from .meta_patch import meta_patched_function, meta_patched_module
from torch.fx.graph import magic_methods, reflectable_magic_methods
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
__all__ = ['ColoTracer']
......@@ -77,54 +84,42 @@ class ColoTracer(Tracer):
"""
Create a proxy for different kinds of operations.
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if self.tracer_type == TracerType.DEFAULT:
# since meta_args is not given
# we just fall back to the original torch.fx.Tracer
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
return proxy
proxy: ColoProxy
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
proxy.meta_data = self.meta_args[target]
return proxy
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
# if the graph is traced for an auto-parallelism module, some extra nodes will be added during
# graph construction to deal with the compatibility between bias addition and all-reduce.
# if no extra manipulation is applied, we just pass the original arguments to the create_proxy function
# to create the node on the computation graph
origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
# dispatch the argument generator depending on the kind and target in the original arguments.
args_metas, _ = extract_meta(*args, **kwargs)
handle = None
if kind == "call_function":
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
meta_target = target
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
......@@ -132,33 +127,26 @@ class ColoTracer(Tracer):
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return proxy
if not isinstance(proxy, Proxy):
raise ValueError("Don't support composite output yet")
if handle is not None:
return handle.generate()
# create nodes using patched arguments
proxy = super().create_proxy(*origin_arguments)
proxy: ColoProxy
meta_out = self._meta_data_computing(
kind,
target,
args,
kwargs,
)
proxy.meta_data = meta_out
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
......@@ -222,6 +210,105 @@ class ColoTracer(Tracer):
else:
raise ValueError(f"Unrecognised tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs):
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target]
return meta_out
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
if kind == "call_function":
# Our meta data will not record the nn.parameter.Parameter attribute.
# This works fine in most cases, but it may cause some problems after
# the bias addition manipulation.
# Therefore, we need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False
if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
torch.nn.parameter.Parameter):
convert_to_parameter = True
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
if convert_to_parameter:
meta_out = torch.nn.Parameter(meta_out)
elif kind == "call_method":
# Our meta data will not record the nn.parameter.Parameter attribute.
# This works fine in most cases, but it may cause some problems after
# the bias addition manipulation.
# Therefore, we need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False
if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter):
convert_to_parameter = True
method = getattr(args_metas[0].__class__, target)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
if convert_to_parameter:
meta_out = torch.nn.Parameter(meta_out)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.parameter.Parameter):
meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
elif isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return None
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return meta_out
def trace(self,
root: nn.Module,
concrete_args: Optional[Dict[str, Tensor]] = None,
......@@ -383,7 +470,7 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
node.meta['activation_checkpoint'] = self.act_ckpt_region_count
return node
......
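A hedged usage sketch of the tracer after this change: shapes are propagated through meta tensors and the graph is built without touching real data. The import path and the meta_args keyword are assumptions based on the code above, not confirmed by this diff.
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer    # import path assumed

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
tracer = ColoTracer()
# the placeholder name must match the forward argument; nn.Sequential uses `input`
graph = tracer.trace(model, meta_args={"input": torch.empty(4, 16, device="meta")})
gm = GraphModule(model, graph)
for node in graph.nodes:
    print(node.op, node.target)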
from .chunk import TensorInfo, TensorState
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
from .gemini_mgr import GeminiManager
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState']
__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
'search_chunk_configuration'
]
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager
__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
class TensorState(Enum):
......@@ -17,9 +18,9 @@ class TensorState(Enum):
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
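STATE_TRANS is a flat tuple of allowed (current, next) state pairs. A hedged sketch of how such a table can guard a state update; the helper name is illustrative, the enum values are assumed for self-containedness, and the pairs below are only a subset of the table above:
from enum import Enum

class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4

ALLOWED = ((TensorState.FREE, TensorState.HOLD),
           (TensorState.HOLD, TensorState.COMPUTE),
           (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
           (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
           (TensorState.READY_FOR_REDUCE, TensorState.HOLD))

def check_transition(cur: TensorState, nxt: TensorState) -> None:
    # hypothetical guard: reject transitions the chunk manager does not allow
    assert (cur, nxt) in ALLOWED, f"illegal transition {cur} -> {nxt}"

check_transition(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD)    # passes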
......@@ -50,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class Chunk:
_total_number = 0
def __init__(self,
......@@ -58,6 +58,7 @@ class Chunk:
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
"""
......@@ -70,8 +71,9 @@ class Chunk:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
init_device (torch.device): optional, the device where the chunk is stored during construction.
Defaults to None, which means the current GPU.
cpu_shard_init (bool): a flag indicating whether the local chunk shard resides on CPU.
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
......@@ -80,13 +82,12 @@ class Chunk:
self.chunk_size = chunk_size
self.utilized_size = 0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self.torch_pg = process_group.dp_process_group()
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be able to be divied by the size of GPU
# the chunk size should be divisible by the dp degree
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
......@@ -96,26 +97,41 @@ class Chunk:
self.dtype = dtype
device = init_device or get_current_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_total = None # we force chunk_total located in CUDA
self.cuda_shard = None # using two attributes for the better interpretation
self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA
# cuda local chunk, which is sharded on GPUs
self.cuda_shard = None
# cpu local chunk, which is sharded on CPUs
self.cpu_shard = None
# whether the chunk is gathered, which means the chunk is duplicated on each process,
# and we should use cuda_global_chunk.
self.is_gathered = True
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
# each tensor is associated with a TensorInfo to track meta info
# each tensor is associated with a TensorInfo to track its meta info
# (state, offset, end)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of all tensors
# the total number of tensors in the chunk
self.num_tensors = 0
# monitor the states of all tensors
self.tensors_state_monitor: Dict[TensorState, int] = dict()
# Record the number of tensors in different states
self.tensor_state_cnter: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensors_state_monitor[state] = 0
self.tensor_state_cnter[state] = 0
# some chunks can keep gathered all the time
# so their computation patterns are the same as that of the parameters in DDP
# If a chunk is kept gathered,
# it is treated the same as a parameter in DDP during training.
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
......@@ -133,6 +149,10 @@ class Chunk:
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
# whether to record l2 norm for the gradient clipping calculation
self.l2_norm_flag = False
self.l2_norm = None
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
......@@ -172,7 +192,7 @@ class Chunk:
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_total
return self.cuda_global_chunk
elif self.cuda_shard is not None:
return self.cuda_shard
else:
......@@ -197,25 +217,37 @@ class Chunk:
if self.keep_gathered:
return False
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""Check if the chunk has inf or nan values on CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
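set_l2_norm stores the squared norm of the chunk's local payload. A hedged sketch of how a caller might combine those values into a clipping coefficient; the aggregation step is an assumption and is not part of this diff (in the sharded case the squared norms would additionally need an all-reduce over the dp group):
import math

def clip_coefficient(chunks, max_norm: float) -> float:
    # hypothetical helper: sum the squared norms recorded by Chunk.set_l2_norm()
    total_sqr = sum(chunk.l2_norm for chunk in chunks)
    total_norm = math.sqrt(total_sqr)
    return min(1.0, max_norm / (total_norm + 1e-6))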
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
......@@ -239,14 +271,11 @@ class Chunk:
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.tensors_state_monitor[tensor_state] += 1
self.tensor_state_cnter[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self, shard_dev: Optional[torch.device] = None):
def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
# sanity check
assert self.chunk_temp is not None
......@@ -258,28 +287,23 @@ class Chunk:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
self.chunk_total = self.chunk_temp.to(get_current_device())
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.chunk_total = self.chunk_temp
self.cuda_global_chunk = self.chunk_temp
self.chunk_temp = None
self.__scatter()
# a gathered chunk never has a shard attribute
if self.keep_gathered:
if shard_dev is None:
shard_dev = get_current_device()
else:
assert shard_dev.type == 'cuda'
elif shard_dev is None:
shard_dev = torch.device('cpu')
return
if self.pin_memory or shard_dev.type == 'cpu':
if self.pin_memory or self.shard_device.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
if shard_dev.type == 'cpu':
if self.shard_device.type == 'cpu':
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
......@@ -352,19 +376,19 @@ class Chunk:
if self.pg_size == 1:
# tricky code here
# just move chunk_total to cuda_shard
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.chunk_total, group=self.torch_pg)
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
free_storage(self.chunk_total)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
......@@ -399,8 +423,8 @@ class Chunk:
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
......@@ -429,7 +453,7 @@ class Chunk:
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.chunk_total.copy_(friend_chunk.chunk_total)
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
......@@ -451,8 +475,8 @@ class Chunk:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
......@@ -466,11 +490,11 @@ class Chunk:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
free_storage(self.chunk_total)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
def __paired_shard_move(self):
......@@ -491,15 +515,15 @@ class Chunk:
def __update_tensors_ptr(self) -> None:
# sanity check
assert self.is_gathered
assert type(self.chunk_total) == torch.Tensor
assert type(self.cuda_global_chunk) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensors_state_monitor[tensor_info.state] -= 1
self.tensor_state_cnter[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensors_state_monitor[tensor_info.state] += 1
self.tensor_state_cnter[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
......@@ -529,9 +553,9 @@ class Chunk:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.chunk_total, prefix='\t\t')
print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
......@@ -547,6 +571,6 @@ class Chunk:
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
return ''.join(output)