Commit 1b5d336d authored by gushiqiao

FIX

parent 39f3609d
@@ -58,7 +58,7 @@ class Conv2dWeight(Conv2dWeightTemplate):
         return destination

     def clear(self):
-        attrs = ["weight", "bias"]
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
@@ -68,7 +68,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
         return destination

     def clear(self):
-        attrs = ["weight", "bias"]
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
@@ -145,7 +145,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.pinned_weight = self.pinned_weight.t()

     def clear(self):
-        attrs = ["weight", "weight_scale", "bias"]
+        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
@@ -34,7 +34,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
         return self.weight.numel() * self.weight.element_size()

     def clear(self):
-        attrs = ["weight", "bias"]
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
@@ -23,7 +23,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)

     def clear(self):
-        attrs = ["weight"]
+        attrs = ["weight", "pinned_weight"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
@@ -22,7 +22,7 @@ class DefaultTensor:
         self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

     def clear(self):
-        attrs = ["tensor"]
+        attrs = ["tensor", "pinned_tensor"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
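Every hunk applies the same fix: these wrapper classes allocate page-locked host copies of their tensors via `torch.empty(..., pin_memory=True)`, but `clear()` previously deleted only the ordinary attributes, so the `pinned_*` buffers survived and kept host memory pinned. Below is a minimal sketch of the pattern after the fix. It is illustrative only; the `load` and `pin` method names and the `PinnedWeightSketch` class are assumptions for the example, not the repository's actual API.

```python
import torch


class PinnedWeightSketch:
    """Illustrative wrapper mirroring the clear() pattern in this commit (not the real class)."""

    def load(self, weight, bias=None):
        self.weight = weight
        self.bias = bias

    def pin(self):
        # Page-locked host copies allow asynchronous host-to-device transfers.
        # Note: pin_memory=True needs an accelerator-enabled PyTorch build.
        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.pinned_weight.copy_(self.weight)
        if self.bias is not None:
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
            self.pinned_bias.copy_(self.bias)

    def clear(self):
        # The fix: also delete the pinned_* attributes so the page-locked
        # host buffers are released, not just the device-side tensors.
        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)


# Usage sketch:
w = PinnedWeightSketch()
w.load(torch.randn(4, 4), torch.randn(4))
if torch.cuda.is_available():
    w.pin()  # only pin when an accelerator is present
w.clear()  # drops weight, bias, and any pinned buffers
assert not hasattr(w, "pinned_weight")
```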