"runtime/rust/src/vscode:/vscode.git/clone" did not exist on "9d6643b7a59220fc4f3ef599c002241dd0bf9965"
Unverified Commit 56bb412e authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

parent 23ba3fc4
......@@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Optional
......@@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):
def pre_fwd_exec(self, module: torch.nn.Module, *args):
tensor_list = []
global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
......@@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
if self._memstarts_collector:
......@@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
tensor_list = []
global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
......@@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
# Store local accumulated grad shard
if param.grad is not None:
......
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
def col_move_to_cpu(t: torch.Tensor):
......@@ -7,7 +7,7 @@ def col_move_to_cpu(t: torch.Tensor):
if t.device.type == 'cpu':
return
ModelDataTracer().delete_tensor(t)
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
t.data = t.data.cpu()
......
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from .async_memtracer import get_cuda_memory_used
from colossalai.utils import get_current_device
......@@ -54,7 +54,7 @@ class MemStatsCollector:
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
assert sampling_cnt == len(self._overall_cuda)
self._model_data_cuda.append(ModelDataTracer().cuda_usage)
self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
self._sampling_cnter.advance()
......
......@@ -5,10 +5,9 @@ import torch
class ModelDataTracer(metaclass=SingletonMeta):
"""
A singleton to trace model data usage during runtime.
We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
including allocation, releasing and moving.
A tracer singleton to trace model data usage during runtime.
The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
NOTE() now the class only trace cuda memory usage
"""
......@@ -32,3 +31,6 @@ class ModelDataTracer(metaclass=SingletonMeta):
@property
def cuda_usage(self):
return self._cuda_usage
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
import torch
......@@ -14,7 +14,7 @@ def test_mem_collector():
collector.sample_memstats()
m_a = torch.randn(10).cuda()
ModelDataTracer().add_tensor(m_a)
GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
b = torch.randn(10).cuda()
# sampling at time 1
......@@ -26,7 +26,7 @@ def test_mem_collector():
collector.sample_memstats()
collector.finish_collection()
collector.reset()
collector.reset_sampling_cnter()
# do nothing after collection, just advance sampling cnter
collector.sample_memstats()
......
......@@ -3,7 +3,7 @@ import functools
import torch
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
# Inserts _post_init_method at the end of init method
......@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
if param.col_attr.grad and self.shard_grad:
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
......@@ -26,15 +26,15 @@ def run_naive_amp():
test_models = ['repeated_computed_layers', 'nested_model']
for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name)
model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
# create model
amp_model = model_builder(checkpoint=True).cuda()
torch_model = copy.deepcopy(amp_model)
# create optimizer
amp_optimizer = optim_builder(amp_model)
torch_optimizer = optim_builder(torch_model)
amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)
# inject naive amp
amp_config = dict(initial_scale=1)
......
......@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
def run_dist(rank, world_size, port, init_device, shard_strategy):
......@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
assert param.col_attr.data.payload.device.type == init_device.type, \
f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
print(f'cuda usgae {ModelDataTracer().cuda_usage}')
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
print(f'numel {model_numel_tensor}')
if init_device.type == 'cuda':
assert (ModelDataTracer().cuda_usage > 0)
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
@pytest.mark.dist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment