Commit 5726795c authored by Cao Yuhang's avatar Cao Yuhang
Browse files

reduce max memory in dist training

parent 4a1ddfc8
import datetime
import torch
import torch.distributed as dist
from .base import LoggerHook
......@@ -39,9 +40,12 @@ class TextLoggerHook(LoggerHook):
# statistic memory
if runner.mode == 'train' and torch.cuda.is_available():
mem = torch.cuda.max_memory_allocated()
mem_mb = int(mem / (1024 * 1024))
mem_str = 'memory: {}, '.format(mem_mb)
log_str += mem_str
mem_mb = torch.tensor([mem / (1024 * 1024)],
dtype=torch.int,
device=torch.device('cuda'))
if runner.world_size > 1:
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
log_str += 'memory: {}, '.format(mem_mb.item())
log_items = []
for name, val in runner.log_buffer.output.items():
if name in ['time', 'data_time']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment