Unverified commit b7045b19 authored by Cjkkkk, committed by GitHub

fix buffer transfer bug (#2045)

parent b8c0fb6e
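Background for the fix: the module wrappers used to cache mask and quantization buffers as tensor objects in a plain dict at construction time. `nn.Module.to(device)` only swaps the tensors held in the module's registered buffer slots, so those cached references kept pointing at the old tensors after the model was moved. The sketch below is illustrative only (not NNI code); it reproduces the effect with a dtype move, which goes through the same `_apply` replacement machinery as a device move:

```python
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('weight_mask', torch.ones(3))
        # old approach: cache the tensor object itself
        self.cached = {'weight_mask': self.weight_mask}
        # new approach: remember only the buffer's name
        self.registered_buffers = ['weight_mask']

    def get_registered_buffers(self):
        # resolve by name, so the lookup always returns the current tensor
        return {name: getattr(self, name) for name in self.registered_buffers}

w = Wrapper()
w.to(torch.float64)   # same replacement mechanism as w.to('cuda')
print(w.cached['weight_mask'].dtype)                    # torch.float32 -- stale reference
print(w.get_registered_buffers()['weight_mask'].dtype)  # torch.float64 -- follows the move
```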
@@ -55,7 +55,7 @@ def test(model, device, test_loader):
 def main():
     torch.manual_seed(0)
-    device = torch.device('cpu')
+    device = torch.device('cuda')
     trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_loader = torch.utils.data.DataLoader(
@@ -66,7 +66,7 @@ def main():
         batch_size=1000, shuffle=True)
     model = Mnist()
-    model.to(device)
+    model = model.to(device)
     '''you can change this to LevelPruner to implement it
         pruner = LevelPruner(configure_list)
@@ -82,14 +82,14 @@ def main():
     pruner = AGP_Pruner(model, configure_list)
     model = pruner.compress()
+    model = model.to(device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
     for epoch in range(10):
         pruner.update_epoch(epoch)
         print('# Epoch {} #'.format(epoch))
         train(model, device, train_loader, optimizer)
         test(model, device, test_loader)
-    pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])
+    pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28], device)
 if __name__ == '__main__':
...
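For context, the corrected call order in the example now reads roughly as below. This is only a condensed restatement of the diff, reusing the example's own names (`Mnist`, `configure_list`, `train`, `test`, the data loaders); it is not a self-contained script:

```python
device = torch.device('cuda')
model = Mnist()
model = model.to(device)

pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()
# compress() wraps layers and registers fresh mask buffers (on CPU),
# so the wrapped model has to be moved to the device again.
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
    pruner.update_epoch(epoch)
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)

# export_model now also receives the device (added in this commit).
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28], device)
```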
@@ -226,7 +226,7 @@ class PrunerModuleWrapper(torch.nn.Module):
         # config and pruner
         self.config = config
         self.pruner = pruner
-        self.registered_buffers = {}
+        self.registered_buffers = []
         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
@@ -234,16 +234,21 @@ class PrunerModuleWrapper(torch.nn.Module):
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
         else:
             self.register_buffer("bias_mask", None)
-        self.registered_buffers['weight_mask'] = self.weight_mask
-        self.registered_buffers['bias_mask'] = self.bias_mask
+        self.registered_buffers.append('weight_mask')
+        self.registered_buffers.append('bias_mask')
         # register user specified buffer
         for name in self.pruner.buffers:
             self.register_buffer(name, self.pruner.buffers[name].clone())
-            self.registered_buffers[name] = getattr(self, name)
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
 
     def forward(self, *inputs):
-        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
         if mask is not None:
             self.weight_mask.copy_(mask['weight'])
         # apply mask to weight
@@ -399,6 +404,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
         # config and pruner
         self.config = config
         self.quantizer = quantizer
+        self.registered_buffers = []
         # register buffer and parameter
         # old_weight is used to store origin weight and weight is used to store quantized weight
@@ -413,10 +419,15 @@ class QuantizerModuleWrapper(torch.nn.Module):
             self.module.register_buffer('weight', self.module.old_weight)
         # register user specified buffer
-        self.registered_buffers = {}
         for name in self.quantizer.buffers:
             self.register_buffer(name, self.quantizer.buffers[name].clone())
-            self.registered_buffers[name] = getattr(self, name)
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
 
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
@@ -426,7 +437,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self.quantizer.quantize_input,
                 self.config,
                 LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             new_weight = self.quantizer.quant_grad.apply(
@@ -435,7 +446,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self.quantizer.quantize_weight,
                 self.config,
                 LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
             self.module.weight = new_weight
             result = self.module(*inputs)
         else:
@@ -448,7 +459,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self.quantizer.quantize_output,
                 self.config,
                 LayerInfo(self.name, self.module),
-                **self.registered_buffers)
+                **self.get_registered_buffers())
         return result
 class Quantizer(Compressor):
...
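The wrapper changes above boil down to one pattern: keep only buffer names in `registered_buffers`, resolve them with `getattr` when the mask/quantization callback is invoked, and pass the live tensors as keyword arguments. A standalone sketch of that idea (the `BufferedWrapper` class and names here are illustrative, not NNI's classes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BufferedWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.registered_buffers = []
        self.register_buffer('weight_mask', torch.ones_like(module.weight))
        self.registered_buffers.append('weight_mask')

    def get_registered_buffers(self):
        # getattr returns whatever tensor is currently registered under the
        # name, so a later .to(device) call is picked up automatically
        return {name: getattr(self, name) for name in self.registered_buffers}

    def forward(self, x):
        bufs = self.get_registered_buffers()
        return F.linear(x, self.module.weight * bufs['weight_mask'], self.module.bias)

wrapper = BufferedWrapper(nn.Linear(4, 2))
print(wrapper(torch.randn(1, 4)).shape)   # torch.Size([1, 2])
```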
@@ -170,7 +170,7 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
             for wrapper in self.get_modules_wrapper():
-                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
+                wrapper.if_calculated.copy_(torch.tensor(0)) # pylint: disable=not-callable
 class SlimPruner(Pruner):
     """
...
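The AGP_Pruner change follows from the same refactor: `if_calculated` is registered with `register_buffer` on the wrapper, so it is reachable directly as an attribute, and `copy_` resets it in place on whichever device the buffer currently lives. A tiny sketch (illustrative class name, not NNI code):

```python
import torch
import torch.nn as nn

class Flagged(nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer exposes the tensor as an attribute of the module
        self.register_buffer('if_calculated', torch.tensor(1))

m = Flagged()
m.if_calculated.copy_(torch.tensor(0))   # reset the flag in place
print(m.if_calculated.item())            # 0
print(dict(m.named_buffers()))           # {'if_calculated': tensor(0)}
```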