# test_scaling.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest
import torch

from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling


def test_scaling():
    """Check the output shapes of ``Scaling.shrink`` and ``Scaling.expand``.

    Two kernel configurations are exercised against the same 10x10 input:

    * kernel ``[5]`` with ``kernel_padding_mode='front'``
    * kernel ``[5, 5]`` with ``kernel_padding_mode='back'``

    Only the resulting shapes are asserted, not the tensor values.
    """
    # 10x10 float32 tensor holding the values 0..99 in row-major order.
    data = torch.arange(100, dtype=torch.float32).reshape(10, 10)

    # One-entry kernel, padded at the front: the expected shapes show the 5
    # acting on the last dimension only (10x10 -> 10x2 on shrink, and
    # 10x10 -> 10x50 on expand).
    scaler = Scaling([5], kernel_padding_mode='front')
    shrunk = scaler.shrink(data)
    assert list(shrunk.shape) == [10, 2]
    expanded = scaler.expand(data, [10, 50])
    assert list(expanded.shape) == [10, 50]

    # Two-entry kernel, padded at the back: both dimensions of the 2-D input
    # shrink by 5 (10x10 -> 2x2). The expand target has one more dimension
    # than the input; presumably the back padding covers that trailing
    # dimension -- confirm against the Scaling implementation.
    scaler = Scaling([5, 5], kernel_padding_mode='back')
    shrunk = scaler.shrink(data)
    assert list(shrunk.shape) == [2, 2]
    expanded = scaler.expand(data, [50, 50, 10])
    assert list(expanded.shape) == [50, 50, 10]

    # No assertion on the result: this call is only expected not to raise
    # for a 3-D shape (presumably validate() raises on an incompatible
    # shape -- confirm against the Scaling implementation).
    scaler.validate([10, 10, 10])


# Allow running this file directly as a script (without the pytest runner);
# a failing assertion in test_scaling() then surfaces as an AssertionError.
if __name__ == '__main__':
    test_scaling()