"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "1661e83032d32bb61451ee2f4aaf31e68c901bef"
matrix.cpp 5.44 KB
Newer Older
1
2
// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
3

4
#include <dlib/python.h>
5
#include <dlib/matrix.h>
6
#include <dlib/string.h>
7
#include <pybind11/pybind11.h>
8
9

using namespace dlib;
10
namespace py = pybind11;
Davis King's avatar
Davis King committed
11
12
using std::string;
using std::ostringstream;
13
14
15
16
17
18
19
20


// Resizes m to nr x nc and sets every element to zero.
//
// Negative dimensions are rejected with a Python IndexError (same message as
// make_matrix_from_size) rather than being passed on to matrix::set_size().
void matrix_set_size(matrix<double>& m, long nr, long nc)
{
    if (nr < 0 || nc < 0)
    {
        PyErr_SetString( PyExc_IndexError, "Input dimensions can't be negative.");
        throw py::error_already_set();
    }
    m.set_size(nr,nc);
    m = 0;
}

21
22
23
24
25
26
27
28
// Python __repr__ for the matrix type.
// Produces "< dlib.matrix containing: " followed by the streamed matrix.
// Surrounding whitespace is trimmed before the closing " >" is appended.
string matrix_double__repr__(matrix<double>& c)
{
    ostringstream text;
    text << "< dlib.matrix containing: \n" << c;
    const string body = trim(text.str());
    return body + " >";
}

29
30
31
32
// Python __str__ for the matrix type: the streamed matrix with leading and
// trailing whitespace removed.
string matrix_double__str__(matrix<double>& c)
{
    ostringstream text;
    text << c;
    const string raw = text.str();
    return trim(raw);
}

36
// Factory used for the (rows, cols) constructor.
// Allocates an nr x nc matrix with every element set to zero.
// Negative dimensions raise a Python IndexError.
std::shared_ptr<matrix<double> > make_matrix_from_size(long nr, long nc)
{
    const bool negative_dims = nr < 0 || nc < 0;
    if (negative_dims)
    {
        PyErr_SetString( PyExc_IndexError, "Input dimensions can't be negative.");
        throw py::error_already_set();
    }

    std::shared_ptr<matrix<double> > result = std::make_shared<matrix<double> >(nr, nc);
    *result = 0;
    return result;
}


50
// Builds a matrix from any Python object that exposes a `shape` attribute
// convertible to a tuple and supports obj[(row, col)] element access
// (numpy arrays, for instance, satisfy this).
// A `shape` whose length is not 2 raises IndexError.
// Each element is fetched individually and converted to double.
std::shared_ptr<matrix<double> > from_object(py::object obj)
{
    py::tuple shape = obj.attr("shape").cast<py::tuple>();
    if (len(shape) != 2)
    {
        PyErr_SetString( PyExc_IndexError, "Input must be a matrix or some kind of 2D array.");
        throw py::error_already_set();
    }

    const long rows = shape[0].cast<long>();
    const long cols = shape[1].cast<long>();
    auto out = std::make_shared<matrix<double>>(rows, cols);

    for (long r = 0; r < rows; ++r)
        for (long col = 0; col < cols; ++col)
            (*out)(r, col) = obj[py::make_tuple(r, col)].cast<double>();

    return out;
}

73
// Builds a matrix from a Python list.
//
//  - If the first element is itself a list, l is treated as a list of rows.
//    Every row must have the same length (checked with pyassert), and the
//    result is len(l) x len(l[0]).
//  - Any other non-empty list is treated as a column vector of len(l) x 1.
//  - An empty list yields an empty (default-constructed) matrix instead of
//    failing on the out-of-range l[0] access.
//
// Elements are converted with cast<double>(); a non-numeric element makes
// that conversion throw.
std::shared_ptr<matrix<double> > from_list(py::list l)
{
    const long nr = py::len(l);
    if (nr == 0)
        return std::make_shared<matrix<double>>();

    if (py::isinstance<py::list>(l[0]))
    {
        const long nc = py::len(l[0]);
        // make sure all the other rows have the same length
        for (long r = 1; r < nr; ++r)
            pyassert(py::len(l[r]) == nc, "All rows of a matrix must have the same number of columns.");

        auto temp = std::make_shared<matrix<double>>(nr,nc);
        for (long r = 0; r < nr; ++r)
        {
            // Convert the row once, not once per element.
            py::list row = l[r].cast<py::list>();
            for (long c = 0; c < nc; ++c)
                (*temp)(r,c) = row[c].cast<double>();
        }
        return temp;
    }

    // In this case we treat it like a column vector
    auto temp = std::make_shared<matrix<double>>(nr,1);
    for (long r = 0; r < nr; ++r)
        (*temp)(r) = l[r].cast<double>();
    return temp;
}

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Python __len__ for the matrix type: the number of rows.
long matrix_double__len__(matrix<double>& c)
{
    const long rows = c.nr();
    return rows;
}

// A lightweight (pointer, length) view of one matrix row; it frees nothing.
// `data` points at the row's first element and `size` is its element count.
// A default-constructed view is empty: a null pointer with zero length.
struct mat_row
{
    mat_row() = default;
    mat_row(double* data_, long size_) : data(data_), size(size_) {}

    double* data = nullptr;
    long size = 0;
};

// Python __setitem__ for a matrix row: stores val at position p.
// Negative positions count back from the end of the row, as in Python.
// Any position still outside [0, size) raises IndexError.
// The lower bound is checked too, matching mat_row__getitem__. Otherwise a
// too-negative index (e.g. row[-10] on a 3-element row) would write before
// the start of the row's storage.
void mat_row__setitem__(mat_row& c, long p, double val)
{
    if (p < 0) {
        p = c.size + p; // negative index
    }
    if (p > c.size-1 || p < 0) {
        PyErr_SetString( PyExc_IndexError, "3 index out of range");
        throw py::error_already_set();
    }
    c.data[p] = val;
}


// Python __str__ for a matrix row: the row's elements streamed as a
// 1 x size matrix.
// Leading and trailing whitespace is trimmed, consistent with the matrix
// __str__ and the row __repr__. Any trailing newline or space emitted by the
// matrix stream operator therefore does not end up in str(row).
string mat_row__str__(mat_row& c)
{
    ostringstream sout;
    sout << mat(c.data,1, c.size);
    return trim(sout.str());
}

139
140
141
142
143
144
145
// Python __repr__ for a matrix row.
// Produces "< matrix row: " followed by the row streamed as a 1 x size
// matrix. Whitespace is trimmed before the closing " >" is appended.
string mat_row__repr__(mat_row& c)
{
    ostringstream text;
    text << "< matrix row: ";
    text << mat(c.data, 1, c.size);
    const string body = trim(text.str());
    return body + " >";
}

146
147
148
149
150
151
152
153
154
155
156
// Python __len__ for a matrix row: the number of elements in the row.
long mat_row__len__(mat_row& m)
{
    const long count = m.size;
    return count;
}

// Python __getitem__ for a matrix row: returns the element at position r.
// Negative positions count back from the end of the row.
// Anything still outside [0, size) raises IndexError.
double mat_row__getitem__(mat_row& m, long r)
{
    const long idx = (r < 0) ? m.size + r : r; // negative index
    const bool in_range = idx >= 0 && idx <= m.size - 1;
    if (!in_range) {
        PyErr_SetString( PyExc_IndexError, "1 index out of range");
        throw py::error_already_set();
    }
    return m.data[idx];
}

// Python __getitem__ for the matrix type: returns a mat_row view of row r.
// Negative r counts back from the last row.
// An out-of-range index raises IndexError. The message reports the index the
// caller actually passed, not the internally adjusted value; for negative
// input the adjusted value would be confusing (m[-10] on 3 rows is -7).
// The returned mat_row points directly into m's storage, so it must not
// outlive m.
mat_row matrix_double__getitem__(matrix<double>& m, long r)
{
    const long requested = r;
    if (r < 0) {
        r = m.nr() + r; // negative index
    }
    if (r > m.nr()-1 || r < 0) {
        PyErr_SetString( PyExc_IndexError, (string("2 index out of range, got ") + cast_to_string(requested)).c_str());
        throw py::error_already_set();
    }
    return mat_row(&m(r,0),m.nc());
}


178
// Getter for the read-only `shape` property: a (rows, cols) tuple.
py::tuple get_matrix_size(matrix<double>& m)
{
    const long rows = m.nr();
    const long cols = m.nc();
    return py::make_tuple(rows, cols);
}

183
// Registers the matrix bindings on module m:
//  - `_row`: a view of one matrix row supporting len, repr, str and
//    element get/set.
//  - `matrix`: dlib::matrix<double>, held by std::shared_ptr. It can be
//    constructed empty, from a list, from an object with a 2D `shape`, or
//    from (rows, cols). It is picklable via getstate/setstate.
void bind_matrix(py::module& m)
{
    py::class_<mat_row>(m, "_row")
        .def("__len__", &mat_row__len__)
        .def("__repr__", &mat_row__repr__)
        .def("__str__", &mat_row__str__)
        .def("__setitem__", &mat_row__setitem__)
        .def("__getitem__", &mat_row__getitem__);

    // The two adjacent docstring literals are concatenated by the compiler,
    // so the first needs a trailing space between the sentences.
    py::class_<matrix<double>, std::shared_ptr<matrix<double>>>(m, "matrix",
        "This object represents a dense 2D matrix of floating point numbers. "
        "Moreover, it binds directly to the C++ type dlib::matrix<double>.")
        .def(py::init<>())
        // pybind11 tries overloads in registration order, so from_list gets
        // first chance at single-argument construction before from_object.
        .def(py::init(&from_list))
        .def(py::init(&from_object))
        .def(py::init(&make_matrix_from_size))
        .def("set_size", &matrix_set_size, py::arg("rows"), py::arg("cols"), "Set the size of the matrix to the given number of rows and columns.")
        .def("__repr__", &matrix_double__repr__)
        .def("__str__", &matrix_double__str__)
        .def("nr", &matrix<double>::nr, "Return the number of rows in the matrix.")
        .def("nc", &matrix<double>::nc, "Return the number of columns in the matrix.")
        .def("__len__", &matrix_double__len__)
        // A returned _row points into the matrix's storage. keep_alive<0,1>
        // keeps argument 1 (self) alive as long as the return value (0).
        .def("__getitem__", &matrix_double__getitem__, py::keep_alive<0,1>())
        .def_property_readonly("shape", &get_matrix_size)
        .def(py::pickle(&getstate<matrix<double>>, &setstate<matrix<double>>));
}