# test_matrix.py: pytest tests for the dlib.matrix Python binding.
from dlib import matrix
try:
    import cPickle as pickle  # Use cPickle on Python 2.7
except ImportError:
    import pickle
from pytest import raises

# numpy is required by the array-conversion tests below. importorskip marks
# this whole module as skipped when numpy is not installed. Calling exit()
# here instead would raise SystemExit during collection, which pytest reports
# as a collection error rather than a skip.
from pytest import importorskip

numpy = importorskip("numpy")


def test_matrix_empty_init():
    """A default-constructed matrix is 0x0 and renders with no elements."""
    empty = matrix()

    # Size as reported by the accessor methods, .shape and len().
    assert (empty.nr(), empty.nc()) == (0, 0)
    assert empty.shape == (0, 0)
    assert len(empty) == 0

    # Text forms: repr keeps its wrapper text, str is completely empty.
    assert repr(empty) == "< dlib.matrix containing: >"
    assert str(empty) == ""


def test_matrix_from_list():
    """A matrix built from a list of rows keeps that layout, prints row by
    row, and comes back unchanged from a pickle round trip."""
    m = matrix([[0, 1, 2],
                [3, 4, 5],
                [6, 7, 8]])
    assert m.nr() == 3
    assert m.nc() == 3
    assert m.shape == (3, 3)
    assert len(m) == 3
    assert repr(m) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >"
    assert str(m) == "0 1 2 \n3 4 5 \n6 7 8"

    # Protocol 2 is the highest pickle protocol Python 2.7 can read.
    deser = pickle.loads(pickle.dumps(m, 2))

    # Check dimensions first. The element loop only visits m's indices, so
    # by itself it would not notice extra rows or columns in deser.
    assert deser.shape == m.shape
    for row in range(m.nr()):
        for col in range(m.nc()):
            assert m[row][col] == deser[row][col]


def test_matrix_from_list_with_invalid_rows():
    """Rows of unequal length are rejected with ValueError."""
    ragged_rows = [[0, 1, 2],
                   [3, 4],
                   [5, 6, 7]]
    with raises(ValueError):
        matrix(ragged_rows)


def test_matrix_from_list_as_column_vector():
    """A flat list becomes a single-column matrix with one row per element."""
    col_vec = matrix([0, 1, 2])

    # Three elements give three rows and one column.
    assert (col_vec.nr(), col_vec.nc()) == (3, 1)
    assert col_vec.shape == (3, 1)
    assert len(col_vec) == 3

    # Each element is printed on its own line.
    assert repr(col_vec) == "< dlib.matrix containing: \n0 \n1 \n2 >"
    assert str(col_vec) == "0 \n1 \n2"


def test_matrix_from_object_with_2d_shape():
    """A 2-D numpy array converts to a matrix with the same layout."""
    src = numpy.array([[0, 1, 2],
                       [3, 4, 5],
                       [6, 7, 8]])
    converted = matrix(src)

    assert (converted.nr(), converted.nc()) == (3, 3)
    assert converted.shape == (3, 3)
    assert len(converted) == 3

    assert repr(converted) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >"
    assert str(converted) == "0 1 2 \n3 4 5 \n6 7 8"


def test_matrix_from_object_without_2d_shape():
    """Converting a 1-D numpy array to a matrix raises IndexError.

    The array is built outside the ``raises`` block so that only the
    ``matrix()`` call can satisfy the expectation. An IndexError raised while
    building the fixture would otherwise make the test pass spuriously.
    """
    one_d = numpy.array([0, 1, 2])
    with raises(IndexError):
        matrix(one_d)


def test_matrix_from_object_without_shape():
    """An object without a ``shape`` attribute (a str) raises AttributeError."""
    not_an_array = "invalid"
    with raises(AttributeError):
        matrix(not_an_array)


def test_matrix_set_size():
    """set_size() on an empty matrix yields a zero-filled 5x5 matrix that
    comes back unchanged from a pickle round trip."""
    m = matrix()
    m.set_size(5, 5)

    assert m.nr() == 5
    assert m.nc() == 5
    assert m.shape == (5, 5)
    assert len(m) == 5
    assert repr(m) == "< dlib.matrix containing: \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 >"
    assert str(m) == "0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0"

    # Protocol 2 is the highest pickle protocol Python 2.7 can read.
    deser = pickle.loads(pickle.dumps(m, 2))

    # Check dimensions first. The element loop only visits m's indices, so
    # by itself it would not notice extra rows or columns in deser.
    assert deser.shape == m.shape
    for row in range(m.nr()):
        for col in range(m.nc()):
            assert m[row][col] == deser[row][col]