"tests/test_legacy/test_engine/test_engine.py" did not exist on "62f4e2eb0760ac8bfe28834b061dbc2bda93ade9"
Commit c7d49329 authored by LuGY's avatar LuGY Committed by Frank Lee
Browse files

[NFC] polish colossalai/utils/tensor_detector/tensor_detector.py code style (#1566)

parent 0c4c9aa6
......@@ -5,18 +5,17 @@ import torch.nn as nn
from typing import Optional
from collections import defaultdict
LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'
class TensorDetector():
def __init__(self,
show_info: bool = True,
log: str = None,
include_cpu: bool = False,
module: Optional[nn.Module] = None
):
module: Optional[nn.Module] = None):
"""This class is a detector to detect tensor on different devices.
Args:
......@@ -57,12 +56,12 @@ class TensorDetector():
def mem_format(self, real_memory_size):
# format the tensor memory into a reasonal magnitude
if real_memory_size >= 2 ** 30:
return str(real_memory_size / (2 ** 30)) + ' GB'
if real_memory_size >= 2 ** 20:
return str(real_memory_size / (2 ** 20)) + ' MB'
if real_memory_size >= 2 ** 10:
return str(real_memory_size / (2 ** 10)) + ' KB'
if real_memory_size >= 2**30:
return str(real_memory_size / (2**30)) + ' GB'
if real_memory_size >= 2**20:
return str(real_memory_size / (2**20)) + ' MB'
if real_memory_size >= 2**10:
return str(real_memory_size / (2**10)) + ' KB'
return str(real_memory_size) + ' B'
def collect_tensors_state(self):
......@@ -125,8 +124,7 @@ class TensorDetector():
minus = outdated + minus
if len(self.order) > 0:
for tensor_id in self.order:
self.info += template_format.format('+',
str(self.tensor_info[tensor_id][0]),
self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]),
str(self.tensor_info[tensor_id][1]),
str(tuple(self.tensor_info[tensor_id][2])),
str(self.tensor_info[tensor_id][3]),
......@@ -137,8 +135,7 @@ class TensorDetector():
self.info += '\n'
if len(minus) > 0:
for tensor_id in minus:
self.info += template_format.format('-',
str(self.saved_tensor_info[tensor_id][0]),
self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]),
str(self.saved_tensor_info[tensor_id][1]),
str(tuple(self.saved_tensor_info[tensor_id][2])),
str(self.saved_tensor_info[tensor_id][3]),
......@@ -148,7 +145,6 @@ class TensorDetector():
# deleted the updated tensor
self.saved_tensor_info.pop(tensor_id)
# trace where is the detect()
locate_info = inspect.stack()[2]
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
......@@ -168,7 +164,7 @@ class TensorDetector():
with open(self.log + '.log', 'a') as f:
f.write(self.info)
def detect(self, include_cpu = False):
def detect(self, include_cpu=False):
self.include_cpu = include_cpu
self.collect_tensors_state()
self.print_tensors_state()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment