Unverified Commit 0dea1407 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] add destructor for stateful tensor (#848)

* add destructor for stateful tensor

* fix colo init context
parent 0f7ed8c1
...@@ -6,7 +6,7 @@ class GeminiMemoryManager(object): ...@@ -6,7 +6,7 @@ class GeminiMemoryManager(object):
def __init__(self, states_cls: EnumMeta): def __init__(self, states_cls: EnumMeta):
super().__init__() super().__init__()
self.states_cls = states_cls self.states_cls = states_cls
self._cnter = 0 # the counter of instances self._cnter = 0 # the counter of instances
self.total_mem = dict() self.total_mem = dict()
self.state_mem = dict() self.state_mem = dict()
...@@ -20,10 +20,10 @@ class GeminiMemoryManager(object): ...@@ -20,10 +20,10 @@ class GeminiMemoryManager(object):
return self._cnter return self._cnter
def reset(self): def reset(self):
self._cnter = 0 # the counter of instances self._cnter = 0 # the counter of instances
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
# memory conditions for all states # memory conditions for all states
for state in self.states_cls: for state in self.states_cls:
...@@ -33,13 +33,16 @@ class GeminiMemoryManager(object): ...@@ -33,13 +33,16 @@ class GeminiMemoryManager(object):
def register_new_instance(self):
    """Record that one more managed stateful-tensor instance exists.

    Increments the internal instance counter; the count is exposed via
    the ``total_number`` property.
    """
    self._cnter = self._cnter + 1
def delete_instance(self):
    """Record that one managed stateful-tensor instance was destroyed.

    Decrements the internal instance counter (the counterpart of
    ``register_new_instance``).
    """
    self._cnter = self._cnter - 1
def print_info(self):
    """Print a summary of instance count and per-device / per-state memory usage.

    Output goes to stdout: first the overall totals (count, CPU bytes,
    CUDA bytes), then one section per state in ``self.states_cls``.
    """
    overall = [
        f"Total number: {self.total_number}",
        f"Total CPU memory occupation: {self.total_mem['cpu']}",
        f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
    ]
    print("\n".join(overall))
    for state in self.states_cls:
        per_state = [
            f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
            f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
        ]
        print("\n".join(per_state))
...@@ -202,3 +202,8 @@ class StatefulTensor(object): ...@@ -202,3 +202,8 @@ class StatefulTensor(object):
# update the information of each state # update the information of each state
manager.state_mem[from_type][state] -= size manager.state_mem[from_type][state] -= size
manager.state_mem[to_type][state] += size manager.state_mem[to_type][state] += size
def __del__(self):
    """Finalizer: release the payload and deregister from the global manager.

    Keeps the GeminiMemoryManager bookkeeping consistent when a
    StatefulTensor is garbage-collected: ``set_null()`` drops the payload
    (so its size is removed from the per-state memory totals) before the
    global instance counter is decremented.

    Note: the original code ended with ``del self``, which only unbinds
    the local name inside this method and has no effect on the object's
    lifetime — it was removed as a misleading no-op.
    """
    self.set_null()
    StatefulTensor.GST_MGR.delete_instance()
...@@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): ...@@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__() super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate self._lazy_memory_allocate = lazy_memory_allocate
def _post_init_method(self, module: torch.nn.Module): def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
""" """
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times? FIXME(fjr) The module may be passed to this function multiple times?
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment