Unverified Commit d6789862 authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

fix mem cal bug (#398)

parent d5cbf7ee
...@@ -50,10 +50,11 @@ class TextLoggerHook(LoggerHook): ...@@ -50,10 +50,11 @@ class TextLoggerHook(LoggerHook):
self._dump_log(runner.meta, runner) self._dump_log(runner.meta, runner)
def _get_max_memory(self, runner): def _get_max_memory(self, runner):
mem = torch.cuda.max_memory_allocated() device = runner.model.output_device
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([mem / (1024 * 1024)], mem_mb = torch.tensor([mem / (1024 * 1024)],
dtype=torch.int, dtype=torch.int,
device=torch.device('cuda')) device=device)
if runner.world_size > 1: if runner.world_size > 1:
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
return mem_mb.item() return mem_mb.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment