Unverified Commit 7ef3507a authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[zero] show model data cuda memory usage after zero context init. (#515)

parent a2e61d61
......@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):
def __init__(self) -> None:
self._cuda_usage = 0
self._start_flag = False
def add_tensor(self, t: torch.Tensor):
def start(self) -> None:
self._start_flag = True
def close(self) -> None:
self._start_flag = False
def add_tensor(self, t: torch.Tensor) -> None:
if not self._start_flag:
return
assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
mem_use = _col_tensor_mem_usage(t)
self._cuda_usage += mem_use
def delete_tensor(self, t: torch.Tensor):
def delete_tensor(self, t: torch.Tensor) -> None:
if not self._start_flag:
return
assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
mem_use = _col_tensor_mem_usage(t)
self._cuda_usage -= mem_use
......
......@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from colossalai.logging import get_dist_logger, disable_existing_loggers
# Inserts _post_init_method at the end of init method
......@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
def _pre_context_exec(self):
"""
The Callback function when entering the context
"""
self.logger = get_dist_logger("ZeroInitContext")
GLOBAL_MODEL_DATA_TRACER.start()
def _post_context_exec(self):
"""The callback function when the context exits.
"""The callback function when exiting context.
"""
if not self.rm_torch_payload_on_the_fly:
for param in self.initialized_param_list:
......@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param.col_attr.remove_torch_payload()
del self.initialized_param_list
GLOBAL_MODEL_DATA_TRACER.close()
cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
def _post_init_method(self, module):
r"""The function to call at the end of the constructor of each nn.Module.
def _post_init_method(self, module: torch.nn.Module):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
for param in module.parameters():
# avoid adapting a param to ShardedParam twice
......@@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
if param.col_attr.sharded_data_tensor.device.type == 'cuda':
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
# if param.col_attr.grad and self.shard_grad:
# self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
......
......@@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
class ShardedModelV2(nn.Module):
"""A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
compared to classic data parallelism while the computational granularity and communication efficiency are retained.
"""
A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
Only 1/#nproc of parameters, gradients are stored in local CUDA memory, so forward and backward
passes can be executed with limited CUDA memory budget.
Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
Args:
......
......@@ -16,6 +16,7 @@ def run_tensor_move(rank):
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
GLOBAL_MODEL_DATA_TRACER.start()
src_t = torch.ones(2, 3).cuda()
GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
......@@ -39,6 +40,7 @@ def run_tensor_move(rank):
colo_model_data_tensor_move(src_t, tgt_t)
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
GLOBAL_MODEL_DATA_TRACER.close()
def test_tensor_move():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment