Unverified Commit be01db37 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)

* polish chunk manager

* polish unit test

* impl add_extern_static_tensor for chunk mgr

* add mem stats collector v2

* polish code

* polish unit test

* polish code

* polish get chunks
parent b3a03e4b
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.tensor import ChunkManager
import torch
import time
......@@ -128,3 +129,19 @@ class MemStatsCollector:
self._start_flag = False
self._step_idx = 0
self._step_total = 0
class MemStatsCollectorV2(MemStatsCollector):
    """Memory statistics collector whose model-data samples come from a chunk manager.

    Args:
        chunk_manager (ChunkManager): The manager whose ``total_mem`` record is
            read each time model data is sampled.
    """

    def __init__(self, chunk_manager: ChunkManager) -> None:
        super().__init__()
        self._chunk_manager = chunk_manager

    def sample_model_data(self) -> None:
        """Record the chunk manager's current CUDA and CPU memory totals.

        Nothing is recorded unless collection has been started
        (``self._start_flag`` is truthy).
        """
        if not self._start_flag:
            return
        mem_record = self._chunk_manager.total_mem
        # Read both entries before appending so the two lists stay in step.
        cuda_usage = mem_record['cuda']
        cpu_usage = mem_record['cpu']
        self._model_data_cuda_list.append(cuda_usage)
        self._model_data_cpu_list.append(cpu_usage)
......@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
for p in self.module.parameters():
if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
......@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
grad = grad / self.dp_world_size
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
chunk = self.chunk_manager.get_chunk(p)
reduced = self.chunk_manager.reduce_chunk(p)
self.chunk_manager.release_chunk(p)
reduced = self.chunk_manager.reduce_chunk(chunk)
self.chunk_manager.release_chunk(chunk)
if reduced and not chunk.is_free:
self.overflow_counter += chunk.has_inf_or_nan
return empty_grad
......
......@@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Deque, Set, List
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
......@@ -172,6 +172,12 @@ class Chunk:
def device_type(self) -> str:
return self.data.device.type
def __hash__(self) -> int:
    """Return a hash derived from this object's identity (``id(self)``).

    The value does not depend on the chunk's contents, which keeps it
    consistent with the identity-based ``__eq__`` defined alongside it.
    """
    return hash(id(self))
def __eq__(self, __o: object) -> bool:
    """Return True only when ``__o`` is this very object (identity comparison)."""
    return self is __o
class ChunkManager:
......@@ -226,8 +232,7 @@ class ChunkManager:
src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
return src_rank
def access_chunk(self, tensor: torch.Tensor) -> None:
chunk = self.tensor_chunk_map[tensor]
def access_chunk(self, chunk: Chunk) -> None:
if chunk in self.accessed_chunks:
return
if not chunk.is_free:
......@@ -236,10 +241,9 @@ class ChunkManager:
self.accessed_chunks.add(chunk)
self.total_mem[chunk.device_type] += chunk.mem
def release_chunk(self, tensor: torch.Tensor) -> None:
def release_chunk(self, chunk: Chunk) -> None:
if not self.enable_distributed_storage:
return
chunk = self.tensor_chunk_map[tensor]
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
......@@ -248,8 +252,7 @@ class ChunkManager:
if chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem
def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
chunk = self.tensor_chunk_map[tensor]
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
if chunk.data.device == device:
return
if chunk.can_move_device and not chunk.is_free:
......@@ -261,8 +264,7 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, tensor: torch.Tensor) -> bool:
chunk = self.tensor_chunk_map[tensor]
def reduce_chunk(self, chunk: Chunk) -> bool:
if not chunk.can_reduce:
return False
self.total_mem[chunk.device_type] -= chunk.mem
......@@ -274,10 +276,6 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def is_chunk_free(self, tensor: torch.Tensor) -> bool:
chunk = self.tensor_chunk_map[tensor]
return chunk.is_free
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
return self.tensor_chunk_map[tensor]
......@@ -285,8 +283,8 @@ class ChunkManager:
self.lazy_release_tensors.extend(tensors)
def exec_lazy_release(self) -> None:
for tensor in self.lazy_release_tensors:
self.release_chunk(tensor)
for chunk in self.get_chunks(self.lazy_release_tensors):
self.release_chunk(chunk)
self.lazy_release_tensors.clear()
def __repr__(self) -> str:
......@@ -340,3 +338,23 @@ class ChunkManager:
for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
if not dest_chunk.is_free:
dest_chunk.copy_(src_chunk)
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
    """Return the distinct chunks that hold the given tensors.

    Each tensor is resolved through ``self.get_chunk``. A chunk shared by
    several tensors appears once, at the position of its first occurrence.

    Args:
        tensors (Iterable[torch.Tensor]): Tensors to look up.

    Returns:
        Tuple[Chunk, ...]: The unique chunks, in first-seen order. Empty when
            ``tensors`` is empty.
    """
    # Chunks hash and compare by identity, so a set gives O(1) duplicate
    # detection; the list alongside it preserves first-seen order.
    seen = set()
    ordered = []
    for tensor in tensors:
        chunk = self.get_chunk(tensor)
        if chunk not in seen:
            seen.add(chunk)
            ordered.append(chunk)
    return tuple(ordered)
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
    """Count an external static tensor in the memory totals.

    The tensor is not managed by the chunk manager; only its memory usage is
    monitored. "Static" means its shape, dtype and device never change, so
    its memory usage never changes either.

    Args:
        tensor (torch.Tensor): An external static tensor, e.g. optimizer state.
    """
    assert tensor not in self.tensor_chunk_map
    device_type = tensor.device.type
    size_in_bytes = tensor.numel() * tensor.element_size()
    self.total_mem[device_type] += size_in_bytes
......@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
for p in params:
self._chunk_manager.access_chunk(p)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
def post_op(self, params):
for p in params:
......
......@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
if not self.module.chunk_manager.is_chunk_free(p):
if not self.module.chunk_manager.get_chunk(p).is_free:
p.data = self.fp16_param_to_fp32_param[p]
else:
assert p.grad is None
......
......@@ -32,7 +32,7 @@ HAS_TENSORS = {
}
}
TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
@parameterize('use_chunk', [False, True])
......@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
rank = gpc.get_local_rank(ParallelMode.DATA)
if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}')
params = [torch.rand(32, 32) for _ in range(3)]
chunk_size = 2048 if use_chunk else None
params = [torch.rand(8, 8) for _ in range(3)]
chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
......@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
for p in params:
chunk_manager.access_chunk(p)
chunks = chunk_manager.get_chunks(params)
for chunk in chunks:
chunk_manager.access_chunk(chunk)
check_has_params(params, [True, True, True])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
for p in params:
chunk_manager.release_chunk(p)
for chunk in chunks:
chunk_manager.release_chunk(chunk)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
for p in params:
chunk_manager.move_chunk(p, torch.device('cpu'))
for chunk in chunks:
chunk_manager.move_chunk(chunk, torch.device('cpu'))
assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
assert chunk_manager.total_mem['cuda'] == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment