Unverified Commit c6a1a626 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
parent 32c1b843
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
class Chunk:
_total_number = 0
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Here we use all-gather operation to gather the whole chunk.
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
It is designed to make the full use of communication and PCIE bandwidth.
Args:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
The default value is None, which is the current GPU
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
self.count_id = Chunk._total_number
Chunk._total_number += 1
self.chunk_size = chunk_size
self.utilized_size = 0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self.torch_pg = process_group.dp_process_group()
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be able to be divied by the size of GPU
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
self.shard_begin = self.shard_size * self.pg_rank
self.shard_end = self.shard_begin + self.shard_size
self.valid_end = self.shard_size
self.dtype = dtype
device = init_device or get_current_device()
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_total = None # we force chunk_total located in CUDA
self.cuda_shard = None # using two attributes for the better interpretation
self.cpu_shard = None
self.is_gathered = True
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
# each tensor is associated with a TensorInfo to track meta info
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of all tensors
self.num_tensors = 0
# monitor the states of all tensors
self.tensors_state_monitor: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensors_state_monitor[state] = 0
# some chunks can keep gathered all the time
# so their computation patterns are the same as that of the parameters in DDP
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
self.pin_memory = pin_memory
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk
self.paired_chunk = None
# if this chunk is synchronized with the optimizer, the flag is True
self.optim_sync_flag = True
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
cpu_memory = 0
if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == 'cuda':
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
else:
if self.is_gathered:
cuda_memory += self.chunk_mem
if self.cuda_shard is not None:
cuda_memory += self.shard_mem
if self.cpu_shard is not None:
cpu_memory += self.shard_mem
return dict(cuda=cuda_memory, cpu=cpu_memory)
@property
def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return 'cuda'
elif self.cuda_shard is not None:
return 'cuda'
else:
return 'cpu'
@property
def payload(self) -> torch.Tensor:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_total
elif self.cuda_shard is not None:
return self.cuda_shard
else:
return self.cpu_shard
@property
def payload_mem(self) -> int:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_mem
else:
return self.shard_mem
@property
def can_move(self) -> bool:
return not self.is_gathered
@property
def can_release(self) -> bool:
if self.keep_gathered:
return False
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
# sanity check
assert self.chunk_temp is not None
assert tensor.dtype == self.dtype
new_utilized_size = self.utilized_size + tensor.numel()
# raise exception when the chunk size is exceeded
if new_utilized_size > self.chunk_size:
raise ChunkFullError
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
# record all the information about the tensor
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.tensors_state_monitor[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self, shard_dev: Optional[torch.device] = None):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
# sanity check
assert self.chunk_temp is not None
# calculate the valid end for each shard
if self.utilized_size <= self.shard_begin:
self.valid_end = 0
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
self.chunk_total = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.chunk_total = self.chunk_temp
self.chunk_temp = None
self.__scatter()
if self.keep_gathered:
if shard_dev is None:
shard_dev = get_current_device()
else:
assert shard_dev.type == 'cuda'
elif shard_dev is None:
shard_dev = torch.device('cpu')
if self.pin_memory or shard_dev.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
if shard_dev.type == 'cpu':
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
"""Move the shard tensor in the chunk.
Args:
device: the device to which the shard will move
force_copy: if True, copy function is called mandatorily
"""
# sanity check
assert not self.is_gathered
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return
if device.type == 'cuda':
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
return
self.cuda_shard = self.cpu_shard.to(get_current_device())
if not self.pin_memory:
self.cpu_shard = None
elif device.type == 'cpu':
if self.cuda_shard is None:
return
if self.pin_memory:
if force_copy or not self.cpu_vis_flag:
self.cpu_shard.copy_(self.cuda_shard)
# if cpu_shard has been visited
# copy operation is not need
else:
self.cpu_shard = self.cuda_shard.cpu()
self.cpu_vis_flag = True
self.cuda_shard = None
else:
raise NotImplementedError
def access_chunk(self):
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if not self.is_gathered:
self.__gather()
self.__update_tensors_ptr()
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
self.__scatter()
def reduce(self):
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
# sanity check
assert self.is_gathered
if self.pg_size == 1:
# tricky code here
# just move chunk_total to cuda_shard
# the communication is not necessary
self.__scatter()
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.chunk_total, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
free_storage(self.chunk_total)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
# sanity check
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
"""
if self.keep_gathered:
return self.utilized_size
else:
return self.valid_end
def init_pair(self, friend_chunk: 'Chunk') -> None:
"""Initialize the paired chunk.
"""
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
else:
assert self.paired_chunk is friend_chunk
assert friend_chunk.paired_chunk is self
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
# sanity check
assert self.paired_chunk is not None
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.chunk_total.copy_(friend_chunk.chunk_total)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
else:
# optim_sync_flag is set to False
# see shard_move function for more details
assert friend_chunk.device_type == 'cpu'
assert self.device_type == 'cpu'
self.optim_sync_flag = False
self.cpu_vis_flag = False
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
self.is_gathered = True
def __scatter(self):
if self.keep_gathered:
return
if self.is_gathered:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
free_storage(self.chunk_total)
self.is_gathered = False
def __paired_shard_move(self):
assert self.paired_chunk is not None, "chunks should be paired before training"
optim_chunk = self.paired_chunk
assert self.chunk_size == optim_chunk.chunk_size
# only be called when optimizer state is in CPU memory
# the grad and param should be in the same device
assert self.cuda_shard is None
temp = optim_chunk.cpu_shard.to(get_current_device())
# avoid to transform FP32 in CPU
self.cuda_shard = temp.to(self.dtype)
if not self.pin_memory:
self.cpu_shard = None
def __update_tensors_ptr(self) -> None:
# sanity check
assert self.is_gathered
assert type(self.chunk_total) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensors_state_monitor[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensors_state_monitor[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
self.__update_one_tensor_info(tensor_info, next_state)
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, __o: object) -> bool:
return self is __o
def __repr__(self, detailed: bool = True):
output = [
"Chunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
]
def print_tensor(tensor, prefix=''):
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
tensor.device))
if self.chunk_temp is not None:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.chunk_total, prefix='\t\t')
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
if self.cpu_shard is not None:
output.append("\tcpu shard:\n")
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
memory_info = self.memory_usage
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
return ''.join(output)
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
class Chunk:
_total_number = 0
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Here we use all-gather operation to gather the whole chunk.
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
It is designed to make the full use of communication and PCIE bandwidth.
Args:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
The default value is None, which is the current GPU
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
self.count_id = Chunk._total_number
Chunk._total_number += 1
self.chunk_size = chunk_size
self.utilized_size = 0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self.torch_pg = process_group.dp_process_group()
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be able to be divied by the size of GPU
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
self.shard_begin = self.shard_size * self.pg_rank
self.shard_end = self.shard_begin + self.shard_size
self.valid_end = self.shard_size
self.dtype = dtype
device = init_device or get_current_device()
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_total = None # we force chunk_total located in CUDA
self.cuda_shard = None # using two attributes for the better interpretation
self.cpu_shard = None
self.is_gathered = True
# configure the init deivce of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
# each tensor is associated with a TensorInfo to track meta info
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of all tensors
self.num_tensors = 0
# monitor the states of all tensors
self.tensors_state_monitor: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensors_state_monitor[state] = 0
# some chunks can keep gathered all the time
# so their computation patterns are the same as that of the parameters in DDP
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
# if pin_memory is True, we allocate a piece of CPU pin-memory
# for it all the time
self.pin_memory = pin_memory
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk
self.paired_chunk = None
# if this chunk is synchronized with the optimizer, the flag is True
self.optim_sync_flag = True
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
cpu_memory = 0
if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == 'cuda':
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
else:
if self.is_gathered:
cuda_memory += self.chunk_mem
if self.cuda_shard is not None:
cuda_memory += self.shard_mem
if self.cpu_shard is not None:
cpu_memory += self.shard_mem
return dict(cuda=cuda_memory, cpu=cpu_memory)
@property
def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return 'cuda'
elif self.cuda_shard is not None:
return 'cuda'
else:
return 'cpu'
@property
def payload(self) -> torch.Tensor:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_total
elif self.cuda_shard is not None:
return self.cuda_shard
else:
return self.cpu_shard
@property
def payload_mem(self) -> int:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_mem
else:
return self.shard_mem
@property
def can_move(self) -> bool:
return not self.is_gathered
@property
def can_release(self) -> bool:
if self.keep_gathered:
return False
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
# sanity check
assert self.chunk_temp is not None
assert tensor.dtype == self.dtype
new_utilized_size = self.utilized_size + tensor.numel()
# raise exception when the chunk size is exceeded
if new_utilized_size > self.chunk_size:
raise ChunkFullError
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
# record all the information about the tensor
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.tensors_state_monitor[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
"""
# sanity check
assert self.chunk_temp is not None
# calculate the valid end for each shard
if self.utilized_size <= self.shard_begin:
self.valid_end = 0
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
self.chunk_total = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.chunk_total = self.chunk_temp
self.chunk_temp = None
self.__scatter()
# always gathered chunk does not have shard
if self.keep_gathered:
return
if self.pin_memory or self.shard_device.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
if self.shard_device.type == 'cpu':
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
"""Move the shard tensor in the chunk.
Args:
device: the device to which the shard will move
force_copy: if True, copy function is called mandatorily
"""
# sanity check
assert not self.is_gathered
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return
if device.type == 'cuda':
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
return
self.cuda_shard = self.cpu_shard.to(get_current_device())
if not self.pin_memory:
self.cpu_shard = None
elif device.type == 'cpu':
if self.cuda_shard is None:
return
if self.pin_memory:
if force_copy or not self.cpu_vis_flag:
self.cpu_shard.copy_(self.cuda_shard)
# if cpu_shard has been visited
# copy operation is not need
else:
self.cpu_shard = self.cuda_shard.cpu()
self.cpu_vis_flag = True
self.cuda_shard = None
else:
raise NotImplementedError
def access_chunk(self):
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if not self.is_gathered:
self.__gather()
self.__update_tensors_ptr()
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
self.__scatter()
def reduce(self):
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
# sanity check
assert self.is_gathered
if self.pg_size == 1:
# tricky code here
# just move chunk_total to cuda_shard
# the communication is not necessary
self.__scatter()
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.chunk_total, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
free_storage(self.chunk_total)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
# sanity check
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
"""
if self.keep_gathered:
return self.utilized_size
else:
return self.valid_end
def init_pair(self, friend_chunk: 'Chunk') -> None:
"""Initialize the paired chunk.
"""
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
else:
assert self.paired_chunk is friend_chunk
assert friend_chunk.paired_chunk is self
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
# sanity check
assert self.paired_chunk is not None
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.chunk_total.copy_(friend_chunk.chunk_total)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
else:
# optim_sync_flag is set to False
# see shard_move function for more details
assert friend_chunk.device_type == 'cpu'
assert self.device_type == 'cpu'
self.optim_sync_flag = False
self.cpu_vis_flag = False
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
self.is_gathered = True
def __scatter(self):
if self.keep_gathered:
return
if self.is_gathered:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
free_storage(self.chunk_total)
self.is_gathered = False
def __paired_shard_move(self):
assert self.paired_chunk is not None, "chunks should be paired before training"
optim_chunk = self.paired_chunk
assert self.chunk_size == optim_chunk.chunk_size
# only be called when optimizer state is in CPU memory
# the grad and param should be in the same device
assert self.cuda_shard is None
temp = optim_chunk.cpu_shard.to(get_current_device())
# avoid to transform FP32 in CPU
self.cuda_shard = temp.to(self.dtype)
if not self.pin_memory:
self.cpu_shard = None
def __update_tensors_ptr(self) -> None:
# sanity check
assert self.is_gathered
assert type(self.chunk_total) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensors_state_monitor[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensors_state_monitor[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
self.__update_one_tensor_info(tensor_info, next_state)
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, __o: object) -> bool:
return self is __o
def __repr__(self, detailed: bool = True):
output = [
"Chunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
]
def print_tensor(tensor, prefix=''):
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
tensor.device))
if self.chunk_temp is not None:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.chunk_total, prefix='\t\t')
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
if self.cpu_shard is not None:
output.append("\tcpu shard:\n")
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
memory_info = self.memory_usage
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
return ''.join(output)
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
class ChunkManager:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.size_config: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.size_config[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
"""Append a tensor to a chunk.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
assert config_key in self.size_config
chunk_size = self.size_config[config_key]
chunk_kwargs = self.kwargs_config[config_key]
group_name = "{}_{}".format(group_type, config_key)
chunk_group = self.__get_chunk_group(group_name)
try:
# append the tensor to the last chunk
chunk_group[-1].append_tensor(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if chunk_group:
# the chunk group is not empty
# close the last chunk
self.__close_one_chunk(chunk_group[-1])
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
chunk = Chunk(
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
pin_memory=pin_memory,
**chunk_kwargs,
)
chunk_group.append(chunk)
chunk.append_tensor(tensor)
self.__add_memory_usage(chunk.memory_usage)
self.tensor_chunk_map[tensor] = chunk_group[-1]
def close_all_groups(self):
"""Close all the chunks of all groups.
"""
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
"""Make the chunk can be used for calculation.
"""
if chunk in self.accessed_chunks:
return
self.__sub_memroy_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
"""
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
self.__sub_memroy_usage(chunk.memory_usage)
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memroy_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""Transit tensor state according to pre-defined state machine.
"""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.
"""
if not chunk.can_reduce:
return False
self.__sub_memroy_usage(chunk.memory_usage)
chunk.reduce()
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
return True
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
"""
Copy data to the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
"""
Return the chunk owning the tensor.
Args:
tensor (torch.Tensor): a torch tensor object
"""
return self.tensor_chunk_map[tensor]
def get_cuda_movable_chunks(self) -> List[Chunk]:
"""
Get all chunks that can be moved.
"""
chunk_list = []
for chunk in self.accessed_chunks:
if chunk.can_release:
chunk_list.append(chunk)
chunk_list.sort(key=lambda x: x.count_id)
return chunk_list
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
"""
Get all chunks owning the input tensors.
Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
"""
chunks = []
for tensor in tensors:
chunk = self.get_chunk(tensor)
if chunk not in chunks:
chunks.append(chunk)
return tuple(chunks)
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
"""Add extern static tensor to chunk manager.
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
Thus, their memory usage never changes.
Args:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __get_chunk_group(self, group_name: str) -> Deque:
"""Register a chunk group.
"""
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk(device)
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] -= v
def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
class ChunkManager:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.size_config: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.size_config[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self,
tensor: ColoTensor,
group_type: str,
config_key: int,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""Append a tensor to a chunk.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
cpu_offload: if True, the chunk will be closed on CPU
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
assert config_key in self.size_config
chunk_size = self.size_config[config_key]
chunk_kwargs = self.kwargs_config[config_key]
group_name = "{}_{}".format(group_type, config_key)
chunk_group = self.__get_chunk_group(group_name)
try:
# append the tensor to the last chunk
chunk_group[-1].append_tensor(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if chunk_group:
# the chunk group is not empty
# close the last chunk
self.__close_one_chunk(chunk_group[-1])
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
chunk = Chunk(
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
**chunk_kwargs,
)
chunk_group.append(chunk)
chunk.append_tensor(tensor)
self.__add_memory_usage(chunk.memory_usage)
self.tensor_chunk_map[tensor] = chunk_group[-1]
def close_all_groups(self):
"""Close all the chunks of all groups.
"""
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
"""Make the chunk can be used for calculation.
"""
if chunk in self.accessed_chunks:
return
self.__sub_memroy_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
"""
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
self.__sub_memroy_usage(chunk.memory_usage)
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memroy_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""Transit tensor state according to pre-defined state machine.
"""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.
"""
if not chunk.can_reduce:
return False
self.__sub_memroy_usage(chunk.memory_usage)
chunk.reduce()
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
return True
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
"""
Copy data to the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
"""
Return the chunk owning the tensor.
Args:
tensor (torch.Tensor): a torch tensor object
"""
return self.tensor_chunk_map[tensor]
def get_cuda_movable_chunks(self) -> List[Chunk]:
"""
Get all chunks that can be moved.
"""
chunk_list = []
for chunk in self.accessed_chunks:
if chunk.can_release:
chunk_list.append(chunk)
chunk_list.sort(key=lambda x: x.count_id)
return chunk_list
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
"""
Get all chunks owning the input tensors.
Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
"""
chunks = []
for tensor in tensors:
chunk = self.get_chunk(tensor)
if chunk not in chunks:
chunks.append(chunk)
return tuple(chunks)
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
"""Add extern static tensor to chunk manager.
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
Thus, their memory usage never changes.
Args:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __get_chunk_group(self, group_name: str) -> Deque:
"""Register a chunk group.
"""
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk()
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] -= v
def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem
import torch
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicyFactory
......@@ -25,6 +28,7 @@ class GeminiManager:
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
......
import torch
import itertools
import torch.distributed as dist
from collections import OrderedDict
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from .reducer import Reducer
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
......@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
cpu_offload = self.gemini_manager.policy_name != 'cuda'
# TODO: get param order and filter unused params
for p in module.parameters():
assert isinstance(p, ColoParameter)
......@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.append_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
......@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
......
from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy
import torch
from functools import lru_cache
from typing import Callable, Optional, Set
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import ProcessGroup, ReplicaSpec
import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@lru_cache(None)
......@@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor):
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrappers the data.
"""
......@@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor):
return self.process_group
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases is limited.
Only existing pg is DP and dist spec is REPLICaTE is valid.
......@@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor):
return self.process_group.tp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
"""set_dist_spec
set dist spec and change the payloads.
Args:
......@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func]
if cls.torch_minor >= 12:
# in order to trigger pre-op hook in the forward of checkpoint module
# we have to capture the `backward` function
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1 # only has 1 paramter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions():
......@@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
"""_redistribute
Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
......@@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor):
self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
"""redistribute
Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
DP process group not changed.
2. If the pg is not not None and not equal to the current process group.
First, convert the tensor as replicated among the TP process group.
Second, reset the process group to the new pg.
......@@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor):
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
......
from enum import Enum
from typing import Dict, Set, Tuple
import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict, Tuple, Set
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device
class OptimState(Enum):
......@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
......
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
import torch.distributed as dist
from functools import partial
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.gemini import TensorState
from colossalai.gemini.chunk import Chunk
def dist_sum(x):
temp = torch.tensor([x], device=get_current_device())
dist.all_reduce(temp)
return temp.item()
def add_param(param_list, param_cp_list, *args, **kwargs):
param = ColoParameter(torch.randn(*args, **kwargs))
param_list.append(param)
param_cp_list.append(param.clone())
def check_euqal(param, param_cp):
if param.device != param_cp.device:
temp = param.data.to(param_cp.device)
else:
temp = param.data
return torch.equal(temp, param_cp.data)
@parameterize('init_device', [None, torch.device('cpu')])
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup()
my_chunk = Chunk(chunk_size=1024,
process_group=pg,
dtype=torch.float32,
init_device=init_device,
keep_gathered=keep_gathered,
pin_memory=pin_memory)
param_list = []
param_cp_list = []
add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
add_param(param_list, param_cp_list, 4, 4)
add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
add_param(param_list, param_cp_list, 1, 1, 5)
for param in param_list:
my_chunk.append_tensor(param)
assert my_chunk.utilized_size == 597
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp)
my_chunk.close_chunk()
if keep_gathered is False:
assert my_chunk.cpu_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == 'cpu'
assert my_chunk.can_move
my_chunk.shard_move(get_current_device())
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move
assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
flag = my_chunk.has_inf_or_nan
assert not flag, "has_inf_or_nan is {}".format(flag)
my_chunk.access_chunk()
assert my_chunk.device_type == 'cuda'
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
assert not my_chunk.can_release
for param in param_list:
my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce
my_chunk.reduce()
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
if keep_gathered is False:
assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == 'cuda'
assert my_chunk.can_move
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_chunk_basic()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_function(4)
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini import TensorState
from colossalai.gemini.chunk import Chunk
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
def dist_sum(x):
temp = torch.tensor([x], device=get_current_device())
dist.all_reduce(temp)
return temp.item()
def add_param(param_list, param_cp_list, *args, **kwargs):
param = ColoParameter(torch.randn(*args, **kwargs))
param_list.append(param)
param_cp_list.append(param.clone())
def check_euqal(param, param_cp):
if param.device != param_cp.device:
temp = param.data.to(param_cp.device)
else:
temp = param.data
return torch.equal(temp, param_cp.data)
@parameterize('init_device', [None, torch.device('cpu')])
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup()
my_chunk = Chunk(chunk_size=1024,
process_group=pg,
dtype=torch.float32,
init_device=init_device,
cpu_shard_init=True,
keep_gathered=keep_gathered,
pin_memory=pin_memory)
param_list = []
param_cp_list = []
add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
add_param(param_list, param_cp_list, 4, 4)
add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
add_param(param_list, param_cp_list, 1, 1, 5)
for param in param_list:
my_chunk.append_tensor(param)
assert my_chunk.utilized_size == 597
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp)
my_chunk.close_chunk()
if keep_gathered is False:
assert my_chunk.cpu_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == 'cpu'
assert my_chunk.can_move
my_chunk.shard_move(get_current_device())
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move
assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
flag = my_chunk.has_inf_or_nan
assert not flag, "has_inf_or_nan is {}".format(flag)
my_chunk.access_chunk()
assert my_chunk.device_type == 'cuda'
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
assert not my_chunk.can_release
for param in param_list:
my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce
my_chunk.reduce()
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
if keep_gathered is False:
assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == 'cuda'
assert my_chunk.can_move
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_chunk_basic()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_function(4)
......@@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
@parameterize('keep_gather', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
......@@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
......@@ -101,4 +102,4 @@ def test_gpt(world_size):
if __name__ == '__main__':
test_gpt(1)
test_gpt(4)
......@@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
......@@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
check_param(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
exam_tiny_example()
@pytest.mark.dist
......@@ -113,4 +158,4 @@ def test_gpt(world_size):
if __name__ == '__main__':
test_gpt(1)
test_gpt(2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment