Commit 5726795c authored by Cao Yuhang

reduce max memory in dist training

parent 4a1ddfc8
 import datetime
 
 import torch
+import torch.distributed as dist
 
 from .base import LoggerHook
@@ -39,9 +40,12 @@ class TextLoggerHook(LoggerHook):
         # statistic memory
         if runner.mode == 'train' and torch.cuda.is_available():
             mem = torch.cuda.max_memory_allocated()
-            mem_mb = int(mem / (1024 * 1024))
-            mem_str = 'memory: {}, '.format(mem_mb)
-            log_str += mem_str
+            mem_mb = torch.tensor([mem / (1024 * 1024)],
+                                  dtype=torch.int,
+                                  device=torch.device('cuda'))
+            if runner.world_size > 1:
+                dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+            log_str += 'memory: {}, '.format(mem_mb.item())
         log_items = []
         for name, val in runner.log_buffer.output.items():
             if name in ['time', 'data_time']:
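For context: the change replaces the per-process Python integer with a one-element CUDA tensor so that dist.reduce can take the element-wise maximum across ranks. After the reduce, rank 0 logs the peak allocation of the busiest GPU rather than only its own, which is what the commit title's "reduce" refers to. Below is a minimal self-contained sketch of the same pattern; the helper name and the dist.is_initialized() guard are illustrative additions, not part of the commit.

# Sketch of the aggregation pattern used in the diff above: each rank turns
# its own peak CUDA memory into a one-element tensor, then a MAX-reduce
# delivers the largest value to rank 0, so the logged number reflects the
# worst-case GPU. Assumes the default process group is already initialised
# (e.g. via torch.distributed.launch).
import torch
import torch.distributed as dist


def max_memory_mb_across_ranks():
    mem = torch.cuda.max_memory_allocated()  # peak bytes on this rank's GPU
    mem_mb = torch.tensor([mem / (1024 * 1024)],
                          dtype=torch.int,
                          device=torch.device('cuda'))
    if dist.is_initialized() and dist.get_world_size() > 1:
        # In-place reduce: only rank 0 is guaranteed to hold the true
        # maximum afterwards; other ranks' tensors are unspecified.
        dist.reduce(mem_mb, dst=0, op=dist.ReduceOp.MAX)
    return mem_mb.item()

The tensor must live on the GPU because the NCCL backend used for distributed training only reduces CUDA tensors; a plain int could not participate in the collective at all.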