Unverified Commit 48084d88 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix fpgm (#1767)

parent 06a98372
...@@ -155,18 +155,18 @@ class FPGMPruner(Pruner): ...@@ -155,18 +155,18 @@ class FPGMPruner(Pruner):
return self.mask_dict.get(layer.name) return self.mask_dict.get(layer.name)
try: try:
weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1])) w = tf.stop_gradient(tf.transpose(tf.reshape(weight, (-1, weight.shape[-1])), [1, 0]))
masks = np.ones(weight.shape) masks = np.ones(w.shape)
num_filters = w.shape[0]
num_kernels = weight.shape[0] * weight.shape[1] num_prune = int(num_filters * config.get('sparsity'))
num_prune = int(num_kernels * config.get('sparsity')) if num_filters < 2 or num_prune < 1:
if num_kernels < 2 or num_prune < 1:
return masks return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) min_gm_idx = self._get_min_gm_kernel_idx(w, num_prune)
for idx in min_gm_idx: for idx in min_gm_idx:
masks[tuple(idx)] = 0. masks[idx] = 0.
finally: finally:
masks = np.transpose(masks, [2, 3, 0, 1]) masks = tf.reshape(tf.transpose(masks, [1, 0]), weight.shape)
masks = tf.Variable(masks) masks = tf.Variable(masks)
self.mask_dict.update({op_name: masks}) self.mask_dict.update({op_name: masks})
self.epoch_pruned_layers.add(layer.name) self.epoch_pruned_layers.add(layer.name)
...@@ -174,22 +174,17 @@ class FPGMPruner(Pruner): ...@@ -174,22 +174,17 @@ class FPGMPruner(Pruner):
return masks return masks
def _get_min_gm_kernel_idx(self, weight, n): def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2
dist_list = [] dist_list = []
for in_i in range(weight.shape[0]): for out_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]): dist_sum = self._get_distance_sum(weight, out_i)
dist_sum = self._get_distance_sum(weight, in_i, out_i) dist_list.append((dist_sum, out_i))
dist_list.append((dist_sum, (in_i, out_i)))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels] return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, in_idx, out_idx): def _get_distance_sum(self, weight, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1])) anchor_w = tf.tile(tf.expand_dims(weight[out_idx], 0), [weight.shape[0], 1])
anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1]) x = weight - anchor_w
x = w - anchor_w x = tf.math.reduce_sum((x*x), -1)
x = tf.math.reduce_sum((x*x), (-2, -1))
x = tf.math.sqrt(x) x = tf.math.sqrt(x)
return tf.math.reduce_sum(x) return tf.math.reduce_sum(x)
......
...@@ -215,9 +215,9 @@ class FPGMPruner(Pruner): ...@@ -215,9 +215,9 @@ class FPGMPruner(Pruner):
masks = torch.ones(weight.size()).type_as(weight) masks = torch.ones(weight.size()).type_as(weight)
try: try:
num_kernels = weight.size(0) * weight.size(1) num_filters = weight.size(0)
num_prune = int(num_kernels * config.get('sparsity')) num_prune = int(num_filters * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1: if num_filters < 2 or num_prune < 1:
return masks return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx: for idx in min_gm_idx:
...@@ -233,13 +233,12 @@ class FPGMPruner(Pruner): ...@@ -233,13 +233,12 @@ class FPGMPruner(Pruner):
dist_list = [] dist_list = []
for out_i in range(weight.size(0)): for out_i in range(weight.size(0)):
for in_i in range(weight.size(1)): dist_sum = self._get_distance_sum(weight, out_i)
dist_sum = self._get_distance_sum(weight, out_i, in_i) dist_list.append((dist_sum, out_i))
dist_list.append((dist_sum, (out_i, in_i)))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels] return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx, in_idx): def _get_distance_sum(self, weight, out_idx):
""" """
Calculate the total distance between a specified filter (by out_idex and in_idx) and Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters. all other filters.
...@@ -257,24 +256,18 @@ class FPGMPruner(Pruner): ...@@ -257,24 +256,18 @@ class FPGMPruner(Pruner):
out_idx: int out_idx: int
output channel index of specified filter, this method calculates the total distance output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters. between this specified filter and all other filters.
in_idx: int
input channel index of specified filter
Returns Returns
------- -------
float32 float32
The total distance The total distance
""" """
logger.debug('weight size: %s', weight.size()) logger.debug('weight size: %s', weight.size())
if len(weight.size()) == 4: # Conv2d assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(-1, weight.size(-2), weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) w = weight.view(weight.size(0), -1)
elif len(weight.size()) == 3: # Conv1d anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
w = weight.view(-1, weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1))
else:
raise RuntimeError('unsupported layer type')
x = w - anchor_w x = w - anchor_w
x = (x * x).sum((-2, -1)) x = (x * x).sum(-1)
x = torch.sqrt(x) x = torch.sqrt(x)
return x.sum() return x.sum()
......
...@@ -54,14 +54,8 @@ def tf2(func): ...@@ -54,14 +54,8 @@ def tf2(func):
return test_tf2_func return test_tf2_func
# for fpgm filter pruner test
k1 = [[1] * 3] * 3 w = np.array([[[[i+1]*3]*3]*5 for i in range(10)])
k2 = [[2] * 3] * 3
k3 = [[3] * 3] * 3
k4 = [[4] * 3] * 3
k5 = [[5] * 3] * 3
w = [[k1, k2, k3, k4, k5]] * 10
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
...@@ -92,16 +86,16 @@ class CompressorTestCase(TestCase): ...@@ -92,16 +86,16 @@ class CompressorTestCase(TestCase):
def test_torch_fpgm_pruner(self): def test_torch_fpgm_pruner(self):
""" """
With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median With filters(kernels) weights defined as above (w), it is obvious that w[4] and w[5] is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper: which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration, Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through: So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))` `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))`
If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through: If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))` `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))`
""" """
model = TorchModel() model = TorchModel()
...@@ -111,12 +105,12 @@ class CompressorTestCase(TestCase): ...@@ -111,12 +105,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0]) masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.])) assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1) pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1]) masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.])) assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2 @tf2
def test_tf_fpgm_pruner(self): def test_tf_fpgm_pruner(self):
...@@ -130,16 +124,15 @@ class CompressorTestCase(TestCase): ...@@ -130,16 +124,15 @@ class CompressorTestCase(TestCase):
layer = tf_compressor.compressor.LayerInfo(model.layers[2]) layer = tf_compressor.compressor.LayerInfo(model.layers[2])
masks = pruner.calc_mask(layer, config_list[0]).numpy() masks = pruner.calc_mask(layer, config_list[0]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3]) masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.])) assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1) pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()]) model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy() masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3]) masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((1)) == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
def test_torch_l1filter_pruner(self): def test_torch_l1filter_pruner(self):
""" """
...@@ -208,6 +201,5 @@ class CompressorTestCase(TestCase): ...@@ -208,6 +201,5 @@ class CompressorTestCase(TestCase):
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment