Unverified Commit 48084d88 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix fpgm (#1767)

parent 06a98372
......@@ -155,18 +155,18 @@ class FPGMPruner(Pruner):
return self.mask_dict.get(layer.name)
try:
weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1]))
masks = np.ones(weight.shape)
num_kernels = weight.shape[0] * weight.shape[1]
num_prune = int(num_kernels * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1:
w = tf.stop_gradient(tf.transpose(tf.reshape(weight, (-1, weight.shape[-1])), [1, 0]))
masks = np.ones(w.shape)
num_filters = w.shape[0]
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
min_gm_idx = self._get_min_gm_kernel_idx(w, num_prune)
for idx in min_gm_idx:
masks[tuple(idx)] = 0.
masks[idx] = 0.
finally:
masks = np.transpose(masks, [2, 3, 0, 1])
masks = tf.reshape(tf.transpose(masks, [1, 0]), weight.shape)
masks = tf.Variable(masks)
self.mask_dict.update({op_name: masks})
self.epoch_pruned_layers.add(layer.name)
......@@ -174,22 +174,17 @@ class FPGMPruner(Pruner):
return masks
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2
dist_list = []
for in_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]):
dist_sum = self._get_distance_sum(weight, in_i, out_i)
dist_list.append((dist_sum, (in_i, out_i)))
for out_i in range(weight.shape[0]):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, in_idx, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1])
x = w - anchor_w
x = tf.math.reduce_sum((x*x), (-2, -1))
def _get_distance_sum(self, weight, out_idx):
anchor_w = tf.tile(tf.expand_dims(weight[out_idx], 0), [weight.shape[0], 1])
x = weight - anchor_w
x = tf.math.reduce_sum((x*x), -1)
x = tf.math.sqrt(x)
return tf.math.reduce_sum(x)
......
......@@ -215,9 +215,9 @@ class FPGMPruner(Pruner):
masks = torch.ones(weight.size()).type_as(weight)
try:
num_kernels = weight.size(0) * weight.size(1)
num_prune = int(num_kernels * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1:
num_filters = weight.size(0)
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
......@@ -233,13 +233,12 @@ class FPGMPruner(Pruner):
dist_list = []
for out_i in range(weight.size(0)):
for in_i in range(weight.size(1)):
dist_sum = self._get_distance_sum(weight, out_i, in_i)
dist_list.append((dist_sum, (out_i, in_i)))
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx, in_idx):
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters.
......@@ -257,24 +256,18 @@ class FPGMPruner(Pruner):
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
in_idx: int
input channel index of specified filter
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
if len(weight.size()) == 4: # Conv2d
w = weight.view(-1, weight.size(-2), weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2))
elif len(weight.size()) == 3: # Conv1d
w = weight.view(-1, weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1))
else:
raise RuntimeError('unsupported layer type')
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum((-2, -1))
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
......
......@@ -54,14 +54,8 @@ def tf2(func):
return test_tf2_func
k1 = [[1] * 3] * 3
k2 = [[2] * 3] * 3
k3 = [[3] * 3] * 3
k4 = [[4] * 3] * 3
k5 = [[5] * 3] * 3
w = [[k1, k2, k3, k4, k5]] * 10
# for fpgm filter pruner test
w = np.array([[[[i+1]*3]*3]*5 for i in range(10)])
class CompressorTestCase(TestCase):
......@@ -92,16 +86,16 @@ class CompressorTestCase(TestCase):
def test_torch_fpgm_pruner(self):
"""
With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median
With filters(kernels) weights defined as above (w), it is obvious that w[4] and w[5] is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))`
So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))`
If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))`
If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))`
"""
model = TorchModel()
......@@ -111,12 +105,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2
def test_tf_fpgm_pruner(self):
......@@ -130,16 +124,15 @@ class CompressorTestCase(TestCase):
layer = tf_compressor.compressor.LayerInfo(model.layers[2])
masks = pruner.calc_mask(layer, config_list[0]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.]))
assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((1)) == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
def test_torch_l1filter_pruner(self):
"""
......@@ -208,6 +201,5 @@ class CompressorTestCase(TestCase):
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment