Commit 08f2920e authored by zhuwenwen's avatar zhuwenwen
Browse files

init colossalai, support dtk2304

parent da3f0934
Pipeline #237 failed with stages
in 0 seconds
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
class Chunk:
_total_number = 0
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Here we use all-gather operation to gather the whole chunk.
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
It is designed to make the full use of communication and PCIE bandwidth.
Args:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU.
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
self.count_id = Chunk._total_number
Chunk._total_number += 1
self.chunk_size = chunk_size
self.utilized_size = 0
self.torch_pg = process_group.dp_process_group()
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be divisible by the dp degree
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
self.shard_begin = self.shard_size * self.pg_rank
self.shard_end = self.shard_begin + self.shard_size
self.valid_end = self.shard_size
self.dtype = dtype
device = init_device or get_current_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA
# cuda local chunk, which is sharded on GPUs
self.cuda_shard = None
# cpu local chunk, which is sharded on CPUs
self.cpu_shard = None
# is the chunks gathers, which means chunks are duplicated on each process,
# and we should use the cuda_global_chunk.
self.is_gathered = True
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
# each tensor is associated with a TensorInfo to track its meta info
# (state, offset, end)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of tensors in the chunk
self.num_tensors = 0
# Record the number of tensors in different states
self.tensor_state_cnter: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensor_state_cnter[state] = 0
# If a chunk is kept gathered,
# they are treated the same as that of the parameters in DDP during training.
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
self.pin_memory = pin_memory
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk
self.paired_chunk = None
# if this chunk is synchronized with the optimizer, the flag is True
self.optim_sync_flag = True
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
# whether to record l2 norm for the gradient clipping calculation
self.l2_norm_flag = False
self.l2_norm = None
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
cpu_memory = 0
if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == 'cuda':
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
else:
if self.is_gathered:
cuda_memory += self.chunk_mem
if self.cuda_shard is not None:
cuda_memory += self.shard_mem
if self.cpu_shard is not None:
cpu_memory += self.shard_mem
return dict(cuda=cuda_memory, cpu=cpu_memory)
@property
def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return 'cuda'
elif self.cuda_shard is not None:
return 'cuda'
else:
return 'cpu'
@property
def payload(self) -> torch.Tensor:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.cuda_global_chunk
elif self.cuda_shard is not None:
return self.cuda_shard
else:
return self.cpu_shard
@property
def payload_mem(self) -> int:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_mem
else:
return self.shard_mem
@property
def can_move(self) -> bool:
return not self.is_gathered
@property
def can_release(self) -> bool:
if self.keep_gathered:
return False
else:
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values on CUDA.
"""
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
# sanity check
assert self.chunk_temp is not None
assert tensor.dtype == self.dtype
new_utilized_size = self.utilized_size + tensor.numel()
# raise exception when the chunk size is exceeded
if new_utilized_size > self.chunk_size:
raise ChunkFullError
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
# record all the information about the tensor
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.tensor_state_cnter[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
"""
# sanity check
assert self.chunk_temp is not None
# calculate the valid end for each shard
if self.utilized_size <= self.shard_begin:
self.valid_end = 0
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.cuda_global_chunk = self.chunk_temp
self.chunk_temp = None
self.__scatter()
# gathered chunk never have shard attribute
if self.keep_gathered:
return
if self.pin_memory or self.shard_device.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
if self.shard_device.type == 'cpu':
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
"""Move the shard tensor in the chunk.
Args:
device: the device to which the shard will move
force_copy: if True, copy function is called mandatorily
"""
# sanity check
assert not self.is_gathered
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return
if device.type == 'cuda':
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
return
self.cuda_shard = self.cpu_shard.to(get_current_device())
if not self.pin_memory:
self.cpu_shard = None
elif device.type == 'cpu':
if self.cuda_shard is None:
return
if self.pin_memory:
if force_copy or not self.cpu_vis_flag:
self.cpu_shard.copy_(self.cuda_shard)
# if cpu_shard has been visited
# copy operation is not need
else:
self.cpu_shard = self.cuda_shard.cpu()
self.cpu_vis_flag = True
self.cuda_shard = None
else:
raise NotImplementedError
def access_chunk(self):
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if not self.is_gathered:
self.__gather()
self.__update_tensors_ptr()
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
self.__scatter()
def reduce(self):
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
# sanity check
assert self.is_gathered
if self.pg_size == 1:
# tricky code here
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
# sanity check
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
"""
if self.keep_gathered:
return self.utilized_size
else:
return self.valid_end
def init_pair(self, friend_chunk: 'Chunk') -> None:
"""Initialize the paired chunk.
"""
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
else:
assert self.paired_chunk is friend_chunk
assert friend_chunk.paired_chunk is self
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
# sanity check
assert self.paired_chunk is not None
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
else:
# optim_sync_flag is set to False
# see shard_move function for more details
assert friend_chunk.device_type == 'cpu'
assert self.device_type == 'cpu'
self.optim_sync_flag = False
self.cpu_vis_flag = False
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
self.is_gathered = True
def __scatter(self):
if self.keep_gathered:
return
if self.is_gathered:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
free_storage(self.cuda_global_chunk)
self.is_gathered = False
def __paired_shard_move(self):
assert self.paired_chunk is not None, "chunks should be paired before training"
optim_chunk = self.paired_chunk
assert self.chunk_size == optim_chunk.chunk_size
# only be called when optimizer state is in CPU memory
# the grad and param should be in the same device
assert self.cuda_shard is None
temp = optim_chunk.cpu_shard.to(get_current_device())
# avoid to transform FP32 in CPU
self.cuda_shard = temp.to(self.dtype)
if not self.pin_memory:
self.cpu_shard = None
def __update_tensors_ptr(self) -> None:
# sanity check
assert self.is_gathered
assert type(self.cuda_global_chunk) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensor_state_cnter[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensor_state_cnter[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
self.__update_one_tensor_info(tensor_info, next_state)
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, __o: object) -> bool:
return self is __o
def __repr__(self, detailed: bool = True):
output = [
"Chunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
]
def print_tensor(tensor, prefix=''):
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
tensor.device))
if self.chunk_temp is not None:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
if self.cpu_shard is not None:
output.append("\tcpu shard:\n")
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
memory_info = self.memory_usage
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
return ''.join(output)
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
class ChunkManager:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def register_tensor(self,
tensor: ColoTensor,
group_type: str,
config_key: int,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""
Register a tensor to the chunk manager.
Then, the tensor should be accessed by `get_chunks`.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group.
config_key: the key of the group's name, the size of the dp world
cpu_offload: if True, the chunk will be closed on CPU
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
assert config_key in self.dp_degree_chunk_size_dict
chunk_size = self.dp_degree_chunk_size_dict[config_key]
chunk_kwargs = self.kwargs_config[config_key]
group_name = "{}_{}".format(group_type, config_key)
chunk_group = self.__get_chunk_group(group_name)
try:
# append the tensor to the last chunk
chunk_group[-1].append_tensor(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if chunk_group:
# the chunk group is not empty
# close the last chunk
self.__close_one_chunk(chunk_group[-1])
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
chunk = Chunk(
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
**chunk_kwargs,
)
chunk_group.append(chunk)
chunk.append_tensor(tensor)
self.__add_memory_usage(chunk.memory_usage)
self.tensor_chunk_map[tensor] = chunk_group[-1]
def close_all_groups(self):
"""Close all the chunks of all groups.
"""
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
"""Make the chunk can be used for calculation.
"""
if chunk in self.accessed_chunks:
return
self.__sub_memroy_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
"""
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
self.__sub_memroy_usage(chunk.memory_usage)
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memroy_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""Transit tensor state according to pre-defined state machine.
"""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.
"""
if not chunk.can_reduce:
return False
self.__sub_memroy_usage(chunk.memory_usage)
chunk.reduce()
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
return True
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
"""
Copy data to the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
"""
Return the chunk owning the tensor.
Args:
tensor (torch.Tensor): a torch tensor object
"""
return self.tensor_chunk_map[tensor]
def get_cuda_movable_chunks(self) -> List[Chunk]:
"""
Get all chunks that can be moved.
"""
chunk_list = []
for chunk in self.accessed_chunks:
if chunk.can_release:
chunk_list.append(chunk)
chunk_list.sort(key=lambda x: x.count_id)
return chunk_list
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
"""
Get all chunks owning the input tensors.
Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
"""
chunks = []
for tensor in tensors:
chunk = self.get_chunk(tensor)
if chunk not in chunks:
chunks.append(chunk)
return tuple(chunks)
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
"""Add extern static tensor to chunk manager.
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
Thus, their memory usage never changes.
Args:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __get_chunk_group(self, group_name: str) -> Deque:
"""Register a chunk group.
"""
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk()
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] -= v
def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem
import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
def in_ddp(param: nn.Parameter) -> bool:
return not getattr(param, '_ddp_to_ignore', False)
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
"""
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
upper_limit = mean + 3 * std
for key in size_dict:
org_list = size_dict[key]
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""Get unused byte for a certain chunk size.
"""
acc = 0
left = 0
for s in size_list:
if s > left:
acc += left
left = chunk_size
left -= s
return left + acc
def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
Args:
param_order (OrderedParamGenerator): the order of param be visied
Returns:
Dict[int, List[ColoParameter]]: a dict contains the classification results.
The keys are dp_degrees and the values are parameters.
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if not in_ddp(param):
continue
param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
params_dict[param_key] = []
params_dict[param_key].append(param)
return params_dict
def search_chunk_configuration(
model: nn.Module,
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
"""search_chunk_configuration
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
search_interval_byte (int): searching interval in byte.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
Returns:
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
"""
if memstas is not None:
param_order = memstas.param_order()
else:
# build the param visited order right now
param_order = OrderedParamGenerator()
for p in model.parameters():
param_order.append(p)
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = classify_params_by_dp_degree(param_order)
config_dict: Dict[int, Dict] = dict()
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
size_list = [p.numel() for p in params_list]
# let small parameters keep gathered in CUDA all the time
total_size = sum(size_list)
if total_size < min_chunk_size_byte:
config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
else:
size_dict[dp_degree] = size_list
if filter_exlarge_params:
_filter_exlarge_params(model, size_dict)
max_size = min_chunk_size_byte
for key in size_dict:
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
min_chunk_waste = float('+inf')
best_chunk_size = start_size
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
for key in size_dict:
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
if temp_waste < min_chunk_waste:
min_chunk_waste = temp_waste
best_chunk_size = chunk_size
for dp_degree in params_dict:
if dp_degree in config_dict:
continue
config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict, min_chunk_waste
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
from colossalai.gemini.memory_tracer import MemStats
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict()
if hidden_dim:
search_interval_byte = hidden_dim
else:
search_interval_byte = 1024 # 1kb
kwargs_dict["search_interval_byte"] = search_interval_byte
if search_range_mb:
kwargs_dict["search_range_mb"] = search_range_mb
if min_chunk_size_mb:
kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
if filter_exlarge_params:
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
total_size = sum(params_sizes) / 1024**2
dist.barrier()
begin = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
dist.barrier()
end = time()
span_s = end - begin
wasted_size /= 1024**2
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()
chunk_manager = ChunkManager(config_dict, init_device)
return chunk_manager
from enum import EnumMeta
class GeminiMemoryManager(object):
def __init__(self, states_cls: EnumMeta):
super().__init__()
self.states_cls = states_cls
self._cnter = 0 # the counter of instances
self.total_mem = dict()
self.state_mem = dict()
self.state_mem['cpu'] = dict()
self.state_mem['cuda'] = dict()
self.reset()
@property
def total_number(self):
return self._cnter
def reset(self):
self._cnter = 0 # the counter of instances
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
# memory conditions for all states
for state in self.states_cls:
self.state_mem['cpu'][state] = 0
self.state_mem['cuda'][state] = 0
def register_new_instance(self):
self._cnter += 1
def delete_instance(self):
self._cnter -= 1
def print_info(self):
print(f"Total number: {self.total_number}",
f"Total CPU memory occupation: {self.total_mem['cpu']}",
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
sep='\n')
for state in self.states_cls:
print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
sep='\n')
import functools
from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from .memory_tracer import ChunkMemStatsCollector
from .placement_policy import PlacementPolicyFactory
class GeminiManager:
"""
Stateful Tensor Manager, inspired from PatrickStar
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
Args:
placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._premade_memstats_ = memstats is not None
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._warmup = True
self._comp_cuda_demand_time = 0
def memstats(self):
"""memstats
get the memory statistics during training.
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
Note, for the latter, you can not access the memstats before warmup iteration finishes.
"""
if self._premade_memstats_:
return self._memstats
else:
assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
return self._mem_stats_collector._memstats
def pre_iter(self, *args):
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.start_collection()
def post_iter(self):
"""This function must be called when each iteration finishes
"""
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
self._compute_idx = -1
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
self._d2h_volume += vol
self._evict_time += evict_time
# move COMPUTE tensors to CUDA
self._h2d_volume += cuda_demand
@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
start = time()
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cuda':
if chunk.is_gathered:
pass
else:
cuda_demand += chunk.chunk_mem - chunk.shard_mem
elif chunk.device_type == 'cpu':
cuda_demand += chunk.chunk_mem
else:
raise RuntimeError
self._comp_cuda_demand_time += time() - start
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks)
@property
def default_device(self):
return self._placement_policy.get_default_device()
def sample_overall_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
def record_model_data_volume(self):
if self._mem_stats_collector:
self._mem_stats_collector.record_model_data_volume()
@property
def chunk_manager(self):
return self._chunk_manager
@property
def cuda_margin_mem(self) -> Optional[float]:
if self._mem_stats_collector:
return self._mem_stats_collector.cuda_margin_mem
return None
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
return PlacementPolicyFactory.get_default_device(policy_name)
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
]
from typing import Optional
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from .memstats_collector import MemStatsCollector
class ChunkMemStatsCollector(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
"""
Memory Statistic Collector for Chunks.
Args:
chunk_manager (ChunkManager): the chunk manager.
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
"""
super().__init__(memstats)
self._chunk_manager = chunk_manager
# override
def record_model_data_volume(self) -> None:
"""
record model data volumn on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
self._memstats.record_max_cuda_model_data(cuda_mem)
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
import json
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import torch
from colossalai.utils import colo_device_memory_used, get_current_device
class MemoryMonitor:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
def __init__(self):
self.time_stamps = []
self.mem_stats = []
def __len__(self):
return len(self.mem_stats)
@abstractmethod
def start(self):
pass
@abstractmethod
def finish(self):
pass
def state_dict(self):
return {
"time_stamps": self.time_stamps,
"mem_stats": self.mem_stats,
}
def save(self, filename):
with open(filename, "w") as f:
json.dump(self.state_dict(), f)
def clear(self):
self.mem_stats.clear()
self.time_stamps.clear()
class AsyncMemoryMonitor(MemoryMonitor):
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
Usage::
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
Args:
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""
def __init__(self, power: int = 10):
super().__init__()
self.keep_measuring = False
current_device = get_current_device()
def _set_cuda_device():
torch.cuda.set_device(current_device)
self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
self.monitor_thread = None
self.interval = 1 / (10**power)
def set_interval(self, power: int):
self.clear()
self.interval = 1 / (10**power)
def is_measuring(self):
return self.keep_measuring
def start(self):
self.keep_measuring = True
self.monitor_thread = self.executor.submit(self._measure_usage)
def finish(self):
if self.keep_measuring is False:
return 0
self.keep_measuring = False
max_usage = self.monitor_thread.result()
self.monitor_thread = None
self.time_stamps.append(time())
self.mem_stats.append(max_usage)
return max_usage
def _measure_usage(self):
max_usage = 0
while self.keep_measuring:
max_usage = max(
max_usage,
colo_device_memory_used(get_current_device()),
)
sleep(self.interval)
return max_usage
class SyncCudaMemoryMonitor(MemoryMonitor):
"""
A synchronized cuda memory monitor.
It only record the maximum allocated cuda memory from start point to finish point.
"""
def __init__(self, power: int = 10):
super().__init__()
def start(self):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
def finish(self) -> int:
"""
return max gpu memory used since latest `start()`.
Returns:
int: max GPU memory
"""
torch.cuda.synchronize()
self.time_stamps.append(time())
max_usage = torch.cuda.max_memory_allocated()
self.mem_stats.append(max_usage)
return max_usage
from typing import Any, Dict, List, Optional
import torch
from colossalai.gemini.memory_tracer import OrderedParamGenerator
class MemStats(object):
def __init__(self) -> None:
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# (preop_step, List[param])
self._step_param_dict = dict()
# (param, List[preop_step])
self._param_step_dict = dict()
# (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
self._step_nmd_dict = dict()
self._param_runtime_order = OrderedParamGenerator()
self._preop_step = 0
self._prev_overall_cuda = -1
self._max_overall_cuda = 0
self._prev_md_cuda = -1
# old version
self._model_data_cuda_list = []
self._model_data_cpu_list = []
self._overall_cuda_list = []
self._overall_cpu_list = []
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
def calc_max_cuda_non_model_data(self):
if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
# compatibility of the old version.
self._non_model_data_cuda_list.append(max_cuda_non_model_data)
def record_max_cuda_model_data(self, val):
self._prev_md_cuda = val
def record_max_cuda_overall_data(self, val):
self._prev_overall_cuda = val
self._max_overall_cuda = max(self._max_overall_cuda, val)
@property
def max_overall_cuda(self):
return self._max_overall_cuda
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
"""
the time step is increased. param list is used between current and the next
time step.
Args:
param_list (List[torch.nn.Parameter]): a list of torch paramters.
"""
for p in param_list:
if p not in self._param_step_dict:
self._param_step_dict[p] = [self._preop_step]
else:
self._param_step_dict[p].append(self._preop_step)
self._param_runtime_order.append(p)
self._step_param_dict[self._preop_step] = param_list
self._preop_step += 1
def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
"""param_used_step
get the timestep list using the param
Args:
param (torch.nn.Parameter): a torch param
Returns:
Optional[List[int]]: a list of int indicates the time step of preop hook.
"""
if param not in self._param_step_dict:
return None
else:
return self._param_step_dict[param]
def param_order(self):
if self._param_runtime_order.is_empty():
raise RuntimeError
else:
return self._param_runtime_order
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list
elif device_type == 'cpu':
return self._non_model_data_cpu_list
else:
raise TypeError
def max_non_model_data(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._non_model_data_cuda_list)
elif device_type == 'cpu':
return max(self._non_model_data_cpu_list)
else:
raise TypeError
def max_overall_cuda(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._overall_cuda_list)
elif device_type == 'cpu':
return max(self._overall_cpu_list)
else:
raise TypeError
def clear(self):
self._model_data_cuda_list = []
self._overall_cuda_list = []
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cpu_list = []
self._non_model_data_cuda_list = []
self._param_runtime_order.clear()
self._step_param_dict.clear()
self._param_step_dict.clear()
self._step_nmd_dict.clear()
self._preop_step = 0
self._prev_overall_cuda = -1
self._prev_md_cuda = -1
import time
from typing import List, Optional
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used
from .memory_stats import MemStats
class MemStatsCollector:
"""
A Memory statistic collector.
It works in two phases.
Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU.
The first iteration of DNN training.
Phase 2. Runtime Phase: use the read-only collected stats
The rest iterations of DNN training.
It has a Sampling counter which is reset after DNN training iteration.
"""
def __init__(self, memstats: Optional[MemStats] = None) -> None:
self._mem_monitor = SyncCudaMemoryMonitor()
self._sampling_time = []
self._start_flag = False
self._step_idx = 0
self._step_total = 0
if memstats is not None:
self.use_outside_memstats = True
self._memstats = memstats
else:
self.use_outside_memstats = False
self._memstats = MemStats()
def next_period_non_model_data_usage(self, device_type: str) -> int:
"""Maximum non model data memory usage during the next Op run
Args:
device_type (str): device type, can be 'cpu' or 'cuda'.
Returns:
int: max non model data memory usage of current sampling period
"""
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
f"step total {self._step_total}"
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data
@property
def sampling_time(self):
return [t - self._sampling_time[0] for t in self._sampling_time]
def start_collection(self):
print('start collection')
self._start_flag = True
self._mem_monitor.start()
def finish_collection(self):
self.sample_overall_data()
# self._step_total = len(self._sampling_time)
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._start_flag = False
self._mem_monitor.finish()
print(f'finish_collection {self._step_total}')
# deprecated
def record_model_data_volume(self) -> None:
"""
Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
def sample_overall_data(self) -> None:
"""
Sampling overall and non model data cuda memory statistics.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_overall = self._mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_overall)
self._memstats.calc_max_cuda_non_model_data()
self._mem_monitor.start()
if self._start_flag:
self._sampling_time.append(time.time())
def clear(self) -> None:
self._memstats.clear()
self._start_flag = False
self._step_idx = 0
self._step_total = 0
from abc import ABC
import torch
class ParamGenerator(ABC):
def append(self, param: torch.nn.Parameter):
pass
def generate(self):
pass
def clear(self):
pass
class OrderedParamGenerator(ParamGenerator):
"""OrderedParamGenerator
Contain the order of parameters visited during runtime.
"""
def __init__(self) -> None:
self.param_visited_order = []
def append(self, param: torch.nn.Parameter):
self.param_visited_order.append(param)
def generate(self):
visited_set = set()
for p in self.param_visited_order:
if p not in visited_set:
yield p
visited_set.add(p)
del visited_set
def is_empty(self):
return len(self.param_visited_order) == 0
def clear(self):
self.param_visited_order = []
import torch.nn
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
__all__ = ['RuntimeMemTracer']
class RuntimeMemTracer():
"""RuntimeMemTracer for the module training using ColoParameter.
Trace non-model memory usage during fwd+bwd process.
It is obtained by using a tensor with the same shape as the training process as the inputs
and running an single fwd+bwd to trace the statistics.
NOTE()
1. The premise to use this tracer is that the target DNN execute the same operations at each iterations,
2. Module buffers are viewed as non-model data.
"""
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()
self._memstats = MemStats()
self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
self.grad_hook = GradMemTracerHook(self._gradstat)
self.cpu_param_data_dict = {}
for p in module.parameters():
p.data = p.data.to(dtype)
self._cast_buffers_to_cuda_dtype()
def parameters_in_runtime_order(self):
return self._memstats._param_runtime_order.generate()
def memstats(self):
return self._memstats
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _backup_params(self):
"""
The function is called before forward. Backup model params on cpu.
"""
for p in self.module.parameters():
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
self.cpu_param_data_dict[p].copy_(p.data)
def _restore_params(self):
"""
This function is called after backward. Restore model params.
"""
for p in self.module.parameters():
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
p.data.copy_(self.cpu_param_data_dict[p])
self.cpu_param_data_dict.clear()
def _pre_forward(self):
self._clear_cuda_mem_info()
self._backup_params()
self.grad_hook.register_grad_hook(self.module)
self.param_op_hook.mem_monitor.start()
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
self.module.zero_grad(set_to_none=True)
self._pre_forward()
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def _post_backward(self):
cuda_volume = self.param_op_hook.mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_volume)
# calc the last Op non model data
self._memstats.calc_max_cuda_non_model_data()
self.grad_hook.remove_grad_hook()
self._restore_params()
def _clear_cuda_mem_info(self):
self._memstats.clear()
self._gradstat.clear()
def _cast_buffers_to_cuda_dtype(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.data.to(self.dtype)
from typing import Optional
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
from .chunk_memstats_collector import ChunkMemStatsCollector
class ModuleInfos:
def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
parent_module: torch.nn.Module):
self.module = module
self.module_name = module_name
self.module_full_name = module_full_name
self.parent_module = parent_module
class StaticMemStatsCollector(ChunkMemStatsCollector):
"""
A Static Memory statistic collector.
"""
def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
super().__init__(chunk_manager)
self.module = module
self.module_info_list = []
def init_mem_stats(self, *inputs):
self.register_opnodes_recursively(self.module)
self.refactor_module()
self.module = self.module.cpu()
self.module.train()
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
gm = symbolic_trace(self.module)
interp = MetaInfoProp(gm)
interp.propagate(*data)
total_mem = 0
for inp in inputs:
total_mem += inp.numel() * inp.element_size()
last_node = None
module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
for node in gm.graph.nodes:
total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem)
last_node = node
self._non_model_data_cuda_list.append(total_mem)
self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = last_node.meta["fwd_mem_out"]
for node in gm.graph.nodes.__reversed__():
cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
total_mem = total_mem - cur_module_mem_fwd
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = node.meta["bwd_mem_out"]
self._step_total = len(self._non_model_data_cuda_list)
self.recover_module()
def refactor_module(self):
for modInfo in self.module_info_list:
temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
def recover_module(self):
for modInfo in self.module_info_list:
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
def register_opnodes_recursively(self,
module: torch.nn.Module,
name: str = "",
full_name: str = "",
parent_module: Optional[torch.nn.Module] = None):
assert isinstance(module, torch.nn.Module)
for child_name, child in module.named_children():
self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
from typing import Optional, Tuple
import torch
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
"""Trace the optimizer memory usage
Args:
optim (ShardedOptimV2): an instance of ShardedOptimver
Returns:
Tuple[int, int]: cuda/cpu memory usage in Byte
"""
if optim is None:
return 0, 0
assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
return optim.get_memory_usage()
def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
"""
Trace the model memory usage.
Args:
model (torch.nn.Module): a torch model
Returns:
Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
"""
if model is None:
return 0, 0
def _get_tensor_mem_use(t: Optional[torch.Tensor]):
if t is None:
return 0, 0
assert isinstance(t, torch.Tensor)
_cpu_mem_usage, _cuda_mem_usage = 0, 0
if t.device.type == 'cpu':
_cpu_mem_usage += t.numel() * t.element_size()
elif t.device.type == 'cuda':
_cuda_mem_usage += t.numel() * t.element_size()
return _cuda_mem_usage, _cpu_mem_usage
cuda_mem_usage = 0
cpu_mem_usage = 0
for param in model.parameters():
if hasattr(param, 'colo_attr'):
t_cuda, t_cpu = param.colo_attr.get_memory_usage()
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
else:
t_cuda, t_cpu = _get_tensor_mem_use(param.data)
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
return cuda_mem_usage, cpu_mem_usage
from .utils import BaseOpHook, register_ophooks_recursively
__all__ = ["BaseOpHook", "register_ophooks_recursively"]
import torch
from colossalai.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardGradMemTracerHook(BaseOpHook):
"""
A hook to process sharded param before and afther FWD and BWD operator executing.
"""
def __init__(self):
super().__init__()
def pre_fwd_exec(self, module: torch.nn.Module, *args):
pass
def post_fwd_exec(self, module: torch.nn.Module, *args):
pass
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters():
assert hasattr(param, '_sharded_grad')
param._sharded_grad.setup()
def post_bwd_exec(self, module: torch.nn.Module, input):
pass
def post_iter(self):
pass
import torch
from colossalai.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
"""
A hook to process sharded param before and afther FWD and BWD operator executing.
"""
def __init__(self):
super().__init__()
def niter(self):
return self._niter
def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
param.data = param.ca_attr.payload()
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
param.data = param.ca_attr.payload()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
param.data = param.ca_attr.payload()
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
param.data = param.ca_attr.payload()
def pre_iter(self):
pass
def post_iter(self):
pass
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
class GradMemStats():
def __init__(self) -> None:
self.unreleased_grad_flag = {}
self.unreleased_grad_volume = 0
def clear(self):
self.unreleased_grad_flag.clear()
self.unreleased_grad_volume = 0
class GradMemTracerHook():
def __init__(self, grad_stats: GradMemStats):
self.grad_hook_list = []
self._grad_stats = grad_stats
def grad_handle(self, p, grad):
assert self._grad_stats.unreleased_grad_flag[p]
free_storage(grad)
self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
self._grad_stats.unreleased_grad_flag[p] = False
def register_grad_hook(self, module: torch.nn.Module):
for p in module.parameters():
if p.requires_grad:
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
self._grad_stats.unreleased_grad_flag[p] = False
def remove_grad_hook(self):
for hook in self.grad_hook_list:
hook.remove()
class ParamMemTracerHook(ColoParamOpHook):
def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
super().__init__()
self._training_phase = TrainingPhase.FORWARD
self._memstats = memstats
self._grad_stats = gradstats
self.mem_monitor = SyncCudaMemoryMonitor()
def _free_cuda_params(self, params):
for p in params:
if p.data.device.type == "cpu":
raise NotImplementedError("Only free cuda memory")
free_storage(p.data)
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
"""
move params to cuda
Args:
params (List[torch.nn.Parameter]): target params
Raises:
NotImplementedError: raise error when param has cpu grad
"""
for p in params:
cur_dev = p.data.device.type
if cur_dev == "cpu":
if p.grad is not None and p.grad.device.type == "cpu":
raise NotImplementedError("Only run in forward propagation")
p.data = torch.empty(p.data.shape,
device="cuda",
dtype=p.data.dtype,
requires_grad=p.data.requires_grad)
elif cur_dev == "cuda":
alloc_storage(p.data)
def record_model_data_volume(self, params):
"""
get cuda model data used by params
"""
data_volume = self._grad_stats.unreleased_grad_volume
for p in params:
cur_model_data_volume = p.data.numel() * p.data.element_size()
data_volume += cur_model_data_volume
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
# add param.grad, actually param.grad is None in this time
data_volume += cur_model_data_volume
if not self._grad_stats.unreleased_grad_flag[p]:
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
self._grad_stats.unreleased_grad_flag[p] = True
# record max non model data used for this Op
self._memstats.record_max_cuda_model_data(data_volume)
def pre_op(self, params):
max_cuda_used_pre_op = self.mem_monitor.finish()
# record max cuda overall data for prev OP.
self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
# record max cuda non model data for prev OP.
self._memstats.calc_max_cuda_non_model_data()
self._allocate_params_on_cuda(params)
# record max cuda model data for current OP
self.record_model_data_volume(params)
self.mem_monitor.start()
self._memstats.increase_preop_step(params)
def post_op(self, params):
self._free_cuda_params(params)
def pre_forward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_forward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
def pre_backward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_backward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
old_training_phase = self._training_phase
try:
self._training_phase = training_phase
yield
finally:
self._training_phase = old_training_phase
switch_to_backward = switch_training_phase
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
import torch
from typing import List, Callable, Optional
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
"""This class allows users to add customized operations
before and after the execution of a PyTorch submodule"""
def __init__(self):
pass
@abstractmethod
def pre_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def post_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
pass
@abstractmethod
def post_bwd_exec(self, module: torch.nn.Module, input):
pass
@abstractmethod
def post_iter(self):
pass
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
module.applied_pre_backward = False
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
output = output.detach()
ctx.pre_backward_function = pre_backward_function
return output
@staticmethod
def backward(ctx, *args):
"""
Args:
activation_grad of the next layer.
Returns:
grad of the input activation.
"""
ctx.pre_backward_function(ctx.module)
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook],
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
assert isinstance(ophook_list, (list, tuple))
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
# Add hooks for submodules
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
# return from flitered module
if filter_fn is not None and filter_fn(module):
return
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_fwd_exec(submodule, *args)
def _post_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
def _post_backward_module_hook(submodule, inputs):
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
module.register_forward_pre_hook(_pre_forward_module_hook)
module.register_forward_hook(_post_forward_module_hook)
module.register_forward_hook(_pre_backward_module_hook)
module.register_forward_pre_hook(_post_backward_module_hook)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment