matrix.h 61 KB
Newer Older
1
// Copyright (C) 2006  Davis E. King (davis@dlib.net)
2
3
4
5
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_
#define DLIB_MATRIx_

6
#include "matrix_exp.h"
7
8
9
10
11
12
13
#include "matrix_abstract.h"
#include "../algs.h"
#include "../serialize.h"
#include "../enable_if.h"
#include <sstream>
#include <algorithm>
#include "../memory_manager.h"
14
#include "../is_kind.h"
15
#include "matrix_data_layout.h"
16
#include "matrix_assign_fwd.h"
17
#include "matrix_op.h"
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

#ifdef _MSC_VER
// Disable the following warnings for Visual Studio

// This warning is:
//    "warning C4355: 'this' : used in base member initializer list"
// Which we get from this code but it is not an error so I'm turning this
// warning off and then turning it back on at the end of the file.
#pragma warning(disable : 4355)

#endif

namespace dlib
{

33
// ----------------------------------------------------------------------------------------
34
35
36
37
38
39
40
41
42
43
44
45

    // This template will perform the needed loop for element multiplication using whichever
    // dimension is provided as a compile time constant (if one is at all).
    //
    // eval(rhs, lhs, r, c) returns element (r,c) of the product lhs*rhs, i.e. the dot product
    // of row r of lhs with column c of rhs.  NOTE the argument order: rhs is passed first.
    // The first term lhs(r,0)*rhs(0,c) is always evaluated, so the shared dimension must be
    // at least 1.
    //
    // The primary template is selected when RHS::NR is non-zero (known at compile time) and
    // bounds the loop with rhs.nr().
    template <
        typename LHS,
        typename RHS,
        long lhs_nc = LHS::NC,
        long rhs_nr = RHS::NR
        >
    struct matrix_multiply_helper 
    {
        typedef typename LHS::type type;

        template <typename RHS_, typename LHS_>
        inline const static type  eval (
            const RHS_& rhs,
            const LHS_& lhs,
            const long r, 
            const long c
        )  
        { 
            type temp = lhs(r,0)*rhs(0,c);
            for (long i = 1; i < rhs.nr(); ++i)
            {
                temp += lhs(r,i)*rhs(i,c);
            }
            return temp;
        }
    };

    // Specialization for RHS::NR == 0 (the row count of RHS is not a compile time constant):
    // the loop is bounded by lhs.nc() instead, the other side of the shared dimension, which
    // may still be a compile time constant through LHS::NC.
    template <
        typename LHS,
        typename RHS,
        long lhs_nc 
        >
    struct matrix_multiply_helper <LHS,RHS,lhs_nc,0>
    {
        typedef typename LHS::type type;

        template <typename RHS_, typename LHS_>
        inline const static type  eval (
            const RHS_& rhs,
            const LHS_& lhs,
            const long r, 
            const long c
        )  
        { 
            type temp = lhs(r,0)*rhs(0,c);
            for (long i = 1; i < lhs.nc(); ++i)
            {
                temp += lhs(r,i)*rhs(i,c);
            }
            return temp;
        }
    };

88
89
90
91
92
93
94
    template <typename LHS, typename RHS>
    class matrix_multiply_exp;

    template <typename LHS, typename RHS>
    struct matrix_traits<matrix_multiply_exp<LHS,RHS> >
    {
        typedef typename LHS::type type;
95
        typedef typename LHS::type const_ret_type;
96
        typedef typename LHS::mem_manager_type mem_manager_type;
97
        typedef typename LHS::layout_type layout_type;
98
99
100
        const static long NR = LHS::NR;
        const static long NC = RHS::NC;

101
102
103
#ifdef DLIB_USE_BLAS
        // if there are BLAS functions to be called then we want to make sure we
        // always evaluate any complex expressions so that the BLAS bindings can happen.
104
105
        const static bool lhs_is_costly = (LHS::cost > 2)&&(RHS::NC != 1 || LHS::cost >= 10000);
        const static bool rhs_is_costly = (RHS::cost > 2)&&(LHS::NR != 1 || RHS::cost >= 10000);
106
#else
107
108
        const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1);
        const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1);
109
#endif
110
111
112
113
114
115
116
117
118

        // Note that if we decide that one of the matrices is too costly we will evaluate it
        // into a temporary.  Doing this resets its cost back to 1.
        const static long lhs_cost = ((lhs_is_costly==true)? 1 : (LHS::cost));
        const static long rhs_cost = ((rhs_is_costly==true)? 1 : (RHS::cost));

        // The cost of evaluating an element of a matrix multiply is the cost of evaluating elements from
        // RHS and LHS times the number of rows/columns in the RHS/LHS matrix.  If we don't know the matrix
        // dimensions then just assume it is really large.
119
        const static long cost = ((tmax<LHS::NC,RHS::NR>::value!=0)? ((lhs_cost+rhs_cost)*tmax<LHS::NC,RHS::NR>::value):(10000));
120
121
122
123
124
    };

    // Chooses how an expression holds an operand of type T: as a reference (T&) when is_ref
    // is true, otherwise as T::matrix_type, i.e. a fully evaluated temporary copy.
    template <typename T, bool is_ref>
    struct conditional_matrix_temp
    {
        typedef typename T::matrix_type type;
    };

    template <typename T>
    struct conditional_matrix_temp<T,true>
    {
        typedef T& type;
    };

125
126
    template <
        typename LHS,
127
        typename RHS
128
        >
129
    class matrix_multiply_exp : public matrix_exp<matrix_multiply_exp<LHS,RHS> >
130
131
132
    {
        /*!
            REQUIREMENTS ON LHS AND RHS
133
                - must be matrix_exp objects.
134
135
136
        !*/
    public:

137
        typedef typename matrix_traits<matrix_multiply_exp>::type type;
138
        typedef typename matrix_traits<matrix_multiply_exp>::const_ret_type const_ret_type;
139
140
141
142
        typedef typename matrix_traits<matrix_multiply_exp>::mem_manager_type mem_manager_type;
        const static long NR = matrix_traits<matrix_multiply_exp>::NR;
        const static long NC = matrix_traits<matrix_multiply_exp>::NC;
        const static long cost = matrix_traits<matrix_multiply_exp>::cost;
143
        typedef typename matrix_traits<matrix_multiply_exp>::layout_type layout_type;
144
145
146
147


        const static bool lhs_is_costly = matrix_traits<matrix_multiply_exp>::lhs_is_costly;
        const static bool rhs_is_costly = matrix_traits<matrix_multiply_exp>::rhs_is_costly;
148
149
        const static bool either_is_costly = lhs_is_costly || rhs_is_costly;
        const static bool both_are_costly = lhs_is_costly && rhs_is_costly;
150
151
152
153

        typedef typename conditional_matrix_temp<const LHS,lhs_is_costly == false>::type LHS_ref_type;
        typedef typename conditional_matrix_temp<const RHS,rhs_is_costly == false>::type RHS_ref_type;

154
155
156
157
        // This constructor exists simply for the purpose of causing a compile time error if
        // someone tries to create an instance of this object with the wrong kind of objects.
        template <typename T1, typename T2>
        matrix_multiply_exp (T1,T2); 
158
159
160
161
162
163
164
165
166
167
168

        inline matrix_multiply_exp (
            const LHS& lhs_,
            const RHS& rhs_
        ) :
            lhs(lhs_),
            rhs(rhs_)
        {
            // You are trying to multiply two incompatible matrices together.  The number of columns 
            // in the matrix on the left must match the number of rows in the matrix on the right.
            COMPILE_TIME_ASSERT(LHS::NC == RHS::NR || LHS::NC*RHS::NR == 0);
169
            DLIB_ASSERT(lhs.nc() == rhs.nr() && lhs.size() > 0 && rhs.size() > 0, 
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
                "\tconst matrix_exp operator*(const matrix_exp& lhs, const matrix_exp& rhs)"
                << "\n\tYou are trying to multiply two incompatible matrices together"
                << "\n\tlhs.nr(): " << lhs.nr()
                << "\n\tlhs.nc(): " << lhs.nc()
                << "\n\trhs.nr(): " << rhs.nr()
                << "\n\trhs.nc(): " << rhs.nc()
                << "\n\t&lhs: " << &lhs 
                << "\n\t&rhs: " << &rhs 
                );

            // You can't multiply matrices together if they don't both contain the same type of elements.
            COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
        }

        inline const type operator() (
185
186
            const long r, 
            const long c
187
188
189
190
191
        ) const 
        { 
            return matrix_multiply_helper<LHS,RHS>::eval(rhs,lhs,r,c);
        }

192
193
194
        inline const type operator() ( long i ) const 
        { return matrix_exp<matrix_multiply_exp>::operator()(i); }

195
196
197
198
199
200
        long nr (
        ) const { return lhs.nr(); }

        long nc (
        ) const { return rhs.nc(); }

201
        template <typename U>
202
        bool aliases (
203
            const matrix_exp<U>& item
204
205
        ) const { return lhs.aliases(item) || rhs.aliases(item); }

206
        template <typename U>
207
        bool destructively_aliases (
208
            const matrix_exp<U>& item
209
210
        ) const { return aliases(item); }

211
212
        LHS_ref_type lhs;
        RHS_ref_type rhs;
213
214
    };

215
216
    // Matrix product m1*m2.  Returns a matrix_multiply_exp expression; dimension checks are
    // performed by its constructor and elements are computed when the expression is read.
    template < typename EXP1, typename EXP2 >
    inline const matrix_multiply_exp<EXP1, EXP2> operator* (
        const matrix_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    )
    {
        return matrix_multiply_exp<EXP1, EXP2>(m1.ref(), m2.ref());
    }

224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
    // Forward declaration.  By default (use_reference == true) the expression keeps a
    // reference to its operand; with false it stores its own copy of the operand.
    template <typename EXP, bool use_reference = true>
    class matrix_mul_scal_exp;

    // -------------------------

    // Now we declare some overloads that cause any scalar multiplications to percolate 
    // up and outside of any matrix multiplies.  Note that we are using the non-reference containing
    // mode of the matrix_mul_scal_exp object since we are passing in locally constructed matrix_multiply_exp 
    // objects.  So the matrix_mul_scal_exp object will contain copies of matrix_multiply_exp objects
    // rather than references to them.  This could result in extra matrix copies if the matrix_multiply_exp
    // decided it should evaluate any of its arguments.  So we also try to not apply this percolating operation 
    // if the matrix_multiply_exp would contain a fully evaluated copy of the original matrix_mul_scal_exp 
    // expression.
    // 
    // Also, the reason we want to apply this transformation in the first place is because it (1) makes
    // the expressions going into matrix multiply expressions simpler and (2) it makes it a lot more
240
    // straightforward to bind BLAS calls to matrix expressions involving scalar multiplies.
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    template < typename EXP1, typename EXP2 >
    inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, matrix_mul_scal_exp<EXP2> >::both_are_costly ,      
                                        matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
        const matrix_mul_scal_exp<EXP1>& m1,
        const matrix_mul_scal_exp<EXP2>& m2
    )
    {
        typedef matrix_multiply_exp<EXP1, EXP2> exp1;
        typedef matrix_mul_scal_exp<exp1,false> exp2;
        return exp2(exp1(m1.m, m2.m), m1.s*m2.s);
    }

    // (s*A)*B ==> s*(A*B).  Disabled when the multiply would evaluate the scaled operand s*A
    // into a temporary anyway.
    template < typename EXP1, typename EXP2 >
    inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, EXP2 >::lhs_is_costly ,      
                                      matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
        const matrix_mul_scal_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    )
    {
        typedef matrix_multiply_exp<EXP1, EXP2> product_type;
        const product_type product(m1.m, m2.ref());
        // The scalar wrapper stores the product by value (use_reference == false).
        return matrix_mul_scal_exp<product_type,false>(product, m1.s);
    }

    // A*(s*B) ==> s*(A*B).  Disabled when the multiply would evaluate the scaled operand s*B
    // into a temporary anyway.
    template < typename EXP1, typename EXP2 >
    inline const typename disable_if_c< matrix_multiply_exp<EXP1, matrix_mul_scal_exp<EXP2> >::rhs_is_costly ,      
                                      matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
        const matrix_exp<EXP1>& m1,
        const matrix_mul_scal_exp<EXP2>& m2
    )
    {
        typedef matrix_multiply_exp<EXP1, EXP2> product_type;
        const product_type product(m1.ref(), m2.m);
        // The scalar wrapper stores the product by value (use_reference == false).
        return matrix_mul_scal_exp<product_type,false>(product, m2.s);
    }

277
// ----------------------------------------------------------------------------------------
278

279
280
    template <typename LHS, typename RHS>
    class matrix_add_exp;
281

282
283
    template <typename LHS, typename RHS>
    struct matrix_traits<matrix_add_exp<LHS,RHS> >
284
    {
285
        typedef typename LHS::type type;
286
        typedef typename LHS::type const_ret_type;
287
        typedef typename LHS::mem_manager_type mem_manager_type;
288
        typedef typename LHS::layout_type layout_type;
289
290
        const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
        const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
291
        const static long cost = LHS::cost+RHS::cost+1;
292
    };
293
294
295
296
297

    template <
        typename LHS,
        typename RHS
        >
298
    class matrix_add_exp : public matrix_exp<matrix_add_exp<LHS,RHS> >
299
300
301
    {
        /*!
            REQUIREMENTS ON LHS AND RHS
302
                - must be matrix_exp objects. 
303
304
        !*/
    public:
305
        typedef typename matrix_traits<matrix_add_exp>::type type;
306
        typedef typename matrix_traits<matrix_add_exp>::const_ret_type const_ret_type;
307
308
309
310
        typedef typename matrix_traits<matrix_add_exp>::mem_manager_type mem_manager_type;
        const static long NR = matrix_traits<matrix_add_exp>::NR;
        const static long NC = matrix_traits<matrix_add_exp>::NC;
        const static long cost = matrix_traits<matrix_add_exp>::cost;
311
        typedef typename matrix_traits<matrix_add_exp>::layout_type layout_type;
312

313
314
315
316
        // This constructor exists simply for the purpose of causing a compile time error if
        // someone tries to create an instance of this object with the wrong kind of objects.
        template <typename T1, typename T2>
        matrix_add_exp (T1,T2); 
317
318

        matrix_add_exp (
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
            const LHS& lhs_,
            const RHS& rhs_
        ) :
            lhs(lhs_),
            rhs(rhs_)
        {
            // You can only add matrices together if they both have the same number of rows and columns.
            COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0);
            COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0);
            DLIB_ASSERT(lhs.nc() == rhs.nc() &&
                   lhs.nr() == rhs.nr(), 
                "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
                << "\n\tYou are trying to add two incompatible matrices together"
                << "\n\tlhs.nr(): " << lhs.nr()
                << "\n\tlhs.nc(): " << lhs.nc()
                << "\n\trhs.nr(): " << rhs.nr()
                << "\n\trhs.nc(): " << rhs.nc()
                << "\n\t&lhs: " << &lhs 
                << "\n\t&rhs: " << &rhs 
                );

            // You can only add matrices together if they both contain the same types of elements.
            COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
        }

        const type operator() (
            long r, 
            long c
        ) const { return lhs(r,c) + rhs(r,c); }

349
350
351
        inline const type operator() ( long i ) const 
        { return matrix_exp<matrix_add_exp>::operator()(i); }

352
        template <typename U>
353
        bool aliases (
354
            const matrix_exp<U>& item
355
356
        ) const { return lhs.aliases(item) || rhs.aliases(item); }

357
        template <typename U>
358
        bool destructively_aliases (
359
            const matrix_exp<U>& item
360
361
362
363
364
365
366
367
        ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); }

        long nr (
        ) const { return lhs.nr(); }

        long nc (
        ) const { return lhs.nc(); }

368
369
        const LHS& lhs;
        const RHS& rhs;
370
371
372
373
374
375
    };

    template <
        typename EXP1,
        typename EXP2
        >
376
    inline const matrix_add_exp<EXP1, EXP2> operator+ (
377
378
379
380
        const matrix_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    )
    {
381
        return matrix_add_exp<EXP1, EXP2>(m1.ref(),m2.ref());
382
383
384
385
    }

// ----------------------------------------------------------------------------------------

386
387
388
389
390
391
392
    template <typename LHS, typename RHS>
    class matrix_subtract_exp;

    // Compile-time properties of matrix_subtract_exp<LHS,RHS>.  Element type, memory manager
    // and layout come from LHS; each static dimension is the larger of the two operands'
    // (0 meaning the dimension is only known at run time).
    template <typename LHS, typename RHS>
    struct matrix_traits<matrix_subtract_exp<LHS,RHS> >
    {
        typedef typename LHS::type type;
        typedef typename LHS::type const_ret_type;
        typedef typename LHS::mem_manager_type mem_manager_type;
        typedef typename LHS::layout_type layout_type;
        const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
        const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
        const static long cost = LHS::cost+RHS::cost+1;
    };

401
402
403
404
    template <
        typename LHS,
        typename RHS
        >
405
    class matrix_subtract_exp : public matrix_exp<matrix_subtract_exp<LHS,RHS> >
406
407
408
    {
        /*!
            REQUIREMENTS ON LHS AND RHS
409
                - must be matrix_exp objects. 
410
411
        !*/
    public:
412
        typedef typename matrix_traits<matrix_subtract_exp>::type type;
413
        typedef typename matrix_traits<matrix_subtract_exp>::const_ret_type const_ret_type;
414
415
416
417
        typedef typename matrix_traits<matrix_subtract_exp>::mem_manager_type mem_manager_type;
        const static long NR = matrix_traits<matrix_subtract_exp>::NR;
        const static long NC = matrix_traits<matrix_subtract_exp>::NC;
        const static long cost = matrix_traits<matrix_subtract_exp>::cost;
418
        typedef typename matrix_traits<matrix_subtract_exp>::layout_type layout_type;
419

420
421
422
423
424

        // This constructor exists simply for the purpose of causing a compile time error if
        // someone tries to create an instance of this object with the wrong kind of objects.
        template <typename T1, typename T2>
        matrix_subtract_exp (T1,T2); 
425
426
427
428

        matrix_subtract_exp (
            const LHS& lhs_,
            const RHS& rhs_
429
        ) : 
430
431
432
433
434
435
436
437
438
            lhs(lhs_),
            rhs(rhs_)
        {
            // You can only subtract one matrix from another if they both have the same number of rows and columns.
            COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0);
            COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0);
            DLIB_ASSERT(lhs.nc() == rhs.nc() &&
                   lhs.nr() == rhs.nr(), 
                "\tconst matrix_exp operator-(const matrix_exp& lhs, const matrix_exp& rhs)"
Davis King's avatar
Davis King committed
439
                << "\n\tYou are trying to subtract two incompatible matrices"
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
                << "\n\tlhs.nr(): " << lhs.nr()
                << "\n\tlhs.nc(): " << lhs.nc()
                << "\n\trhs.nr(): " << rhs.nr()
                << "\n\trhs.nc(): " << rhs.nc()
                << "\n\t&lhs: " << &lhs 
                << "\n\t&rhs: " << &rhs 
                );

            // You can only subtract one matrix from another if they both contain elements of the same type.
            COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
        }

        const type operator() (
            long r, 
            long c
        ) const { return lhs(r,c) - rhs(r,c); }

457
458
459
        inline const type operator() ( long i ) const 
        { return matrix_exp<matrix_subtract_exp>::operator()(i); }

460
        template <typename U>
461
        bool aliases (
462
            const matrix_exp<U>& item
463
464
        ) const { return lhs.aliases(item) || rhs.aliases(item); }

465
        template <typename U>
466
        bool destructively_aliases (
467
            const matrix_exp<U>& item
468
469
470
471
472
473
474
475
        ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); }

        long nr (
        ) const { return lhs.nr(); }

        long nc (
        ) const { return lhs.nc(); }

476
477
        const LHS& lhs;
        const RHS& rhs;
478
479
480
481
482
483
    };

    template <
        typename EXP1,
        typename EXP2
        >
484
    inline const matrix_subtract_exp<EXP1, EXP2> operator- (
485
486
487
488
        const matrix_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    )
    {
489
        return matrix_subtract_exp<EXP1, EXP2>(m1.ref(),m2.ref());
490
491
492
493
    }

// ----------------------------------------------------------------------------------------

Davis King's avatar
Davis King committed
494
    template <typename M>
495
496
    class matrix_div_scal_exp;

Davis King's avatar
Davis King committed
497
498
    template <typename M>
    struct matrix_traits<matrix_div_scal_exp<M> >
499
500
    {
        typedef typename M::type type;
501
        typedef typename M::type const_ret_type;
502
        typedef typename M::mem_manager_type mem_manager_type;
503
        typedef typename M::layout_type layout_type;
504
505
506
507
508
        const static long NR = M::NR;
        const static long NC = M::NC;
        const static long cost = M::cost+1;
    };

509
    template <
Davis King's avatar
Davis King committed
510
        typename M
511
        >
Davis King's avatar
Davis King committed
512
    class matrix_div_scal_exp : public matrix_exp<matrix_div_scal_exp<M> >
513
514
515
    {
        /*!
            REQUIREMENTS ON M 
516
                - must be a matrix_exp object.
517
518
        !*/
    public:
519
        typedef typename matrix_traits<matrix_div_scal_exp>::type type;
520
        typedef typename matrix_traits<matrix_div_scal_exp>::const_ret_type const_ret_type;
521
522
523
524
        typedef typename matrix_traits<matrix_div_scal_exp>::mem_manager_type mem_manager_type;
        const static long NR = matrix_traits<matrix_div_scal_exp>::NR;
        const static long NC = matrix_traits<matrix_div_scal_exp>::NC;
        const static long cost = matrix_traits<matrix_div_scal_exp>::cost;
525
        typedef typename matrix_traits<matrix_div_scal_exp>::layout_type layout_type;
526

527
528
529
530

        // This constructor exists simply for the purpose of causing a compile time error if
        // someone tries to create an instance of this object with the wrong kind of objects.
        template <typename T1>
Davis King's avatar
Davis King committed
531
        matrix_div_scal_exp (T1, const type&); 
532

533
        matrix_div_scal_exp (
534
            const M& m_,
Davis King's avatar
Davis King committed
535
            const type& s_
536
537
538
539
540
541
542
543
544
545
        ) :
            m(m_),
            s(s_)
        {}

        const type operator() (
            long r, 
            long c
        ) const { return m(r,c)/s; }

546
547
548
        inline const type operator() ( long i ) const 
        { return matrix_exp<matrix_div_scal_exp>::operator()(i); }

549
        template <typename U>
550
        bool aliases (
551
            const matrix_exp<U>& item
552
553
        ) const { return m.aliases(item); }

554
        template <typename U>
555
        bool destructively_aliases (
556
            const matrix_exp<U>& item
557
558
559
560
561
562
563
564
        ) const { return m.destructively_aliases(item); }

        long nr (
        ) const { return m.nr(); }

        long nc (
        ) const { return m.nc(); }

565
        const M& m;
Davis King's avatar
Davis King committed
566
        const type s;
567
568
569
570
    };

    template <
        typename EXP,
571
        typename S
572
        >
573
    inline const typename enable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_div_scal_exp<EXP> >::type operator/ (
574
575
576
577
        const matrix_exp<EXP>& m,
        const S& s
    )
    {
578
        return matrix_div_scal_exp<EXP>(m.ref(),static_cast<typename EXP::type>(s));
579
580
581
582
    }

// ----------------------------------------------------------------------------------------

583
584
    // Compile-time properties of matrix_mul_scal_exp<M,use_reference>: same element type,
    // dimensions, memory manager and layout as M, with one extra unit of cost for the
    // multiplication.  The storage mode (use_reference) does not affect any of them.
    template <typename M, bool use_reference >
    struct matrix_traits<matrix_mul_scal_exp<M,use_reference> >
    {
        typedef typename M::type type;
        typedef typename M::type const_ret_type;
        typedef typename M::mem_manager_type mem_manager_type;
        typedef typename M::layout_type layout_type;
        const static long NR = M::NR;
        const static long NC = M::NC;
        const static long cost = M::cost+1;
    };

595
596
597
598
    // Chooses the storage type for an operand: a reference (T&) when is_ref is true,
    // otherwise T itself (a stored copy).
    template <typename T, bool is_ref>
    struct conditional_reference
    {
        typedef T type;
    };

    template <typename T>
    struct conditional_reference<T,true>
    {
        typedef T& type;
    };


599
    template <
600
601
        typename M,
        bool use_reference
602
        >
603
    class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M,use_reference> >
604
605
606
    {
        /*!
            REQUIREMENTS ON M 
607
                - must be a matrix_exp object.
608
609
610

        !*/
    public:
611
        typedef typename matrix_traits<matrix_mul_scal_exp>::type type;
612
        typedef typename matrix_traits<matrix_mul_scal_exp>::const_ret_type const_ret_type;
613
614
615
616
        typedef typename matrix_traits<matrix_mul_scal_exp>::mem_manager_type mem_manager_type;
        const static long NR = matrix_traits<matrix_mul_scal_exp>::NR;
        const static long NC = matrix_traits<matrix_mul_scal_exp>::NC;
        const static long cost = matrix_traits<matrix_mul_scal_exp>::cost;
617
        typedef typename matrix_traits<matrix_mul_scal_exp>::layout_type layout_type;
618

619
620
        // You aren't allowed to multiply a matrix of matrices by a scalar.   
        COMPILE_TIME_ASSERT(is_matrix<type>::value == false);
621
622
623
624

        // This constructor exists simply for the purpose of causing a compile time error if
        // someone tries to create an instance of this object with the wrong kind of objects.
        template <typename T1>
Davis King's avatar
Davis King committed
625
        matrix_mul_scal_exp (T1, const type&); 
626
627

        matrix_mul_scal_exp (
628
            const M& m_,
Davis King's avatar
Davis King committed
629
            const type& s_
630
631
632
633
634
635
636
637
638
639
        ) :
            m(m_),
            s(s_)
        {}

        const type operator() (
            long r, 
            long c
        ) const { return m(r,c)*s; }

640
641
642
        inline const type operator() ( long i ) const 
        { return matrix_exp<matrix_mul_scal_exp>::operator()(i); }

643
        template <typename U>
644
        bool aliases (
645
            const matrix_exp<U>& item
646
647
        ) const { return m.aliases(item); }

648
        template <typename U>
649
        bool destructively_aliases (
650
            const matrix_exp<U>& item
651
652
653
654
655
656
657
658
        ) const { return m.destructively_aliases(item); }

        long nr (
        ) const { return m.nr(); }

        long nc (
        ) const { return m.nc(); }

659
660
661
        typedef typename conditional_reference<const M,use_reference>::type M_ref_type;

        M_ref_type m;
Davis King's avatar
Davis King committed
662
        const type s;
663
664
665
666
667
668
    };

    template <
        typename EXP,
        typename S 
        >
Davis King's avatar
Davis King committed
669
    inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
670
671
672
673
        const matrix_exp<EXP>& m,
        const S& s
    )
    {
674
675
        typedef typename EXP::type type;
        return matrix_mul_scal_exp<EXP>(m.ref(),static_cast<type>(s));
676
677
    }

678
679
680
681
682
683
684
685
686
687
    // (A*s1)*s2: folds the two scalars into one instead of nesting scalar expressions.
    template <
        typename EXP,
        typename S,
        bool B
        >
    inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
        const matrix_mul_scal_exp<EXP,B>& m,
        const S& s
    )
    {
        typedef typename EXP::type type;
        return matrix_mul_scal_exp<EXP>(m.m,static_cast<type>(s)*m.s);
    }

692
693
694
695
    template <
        typename EXP,
        typename S 
        >
Davis King's avatar
Davis King committed
696
    inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
697
698
699
700
        const S& s,
        const matrix_exp<EXP>& m
    )
    {
701
702
        typedef typename EXP::type type;
        return matrix_mul_scal_exp<EXP>(m.ref(),static_cast<type>(s));
703
704
705
    }

    template <
706
707
708
        typename EXP,
        typename S,
        bool B
709
        >
710
711
712
    inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
        const S& s,
        const matrix_mul_scal_exp<EXP,B>& m
713
714
    )
    {
715
716
        typedef typename EXP::type type;
        return matrix_mul_scal_exp<EXP>(m.m,static_cast<type>(s)*m.s);
717
718
719
    }

    template <
720
721
        typename EXP ,
        typename S
722
        >
723
    inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
724
        const matrix_exp<EXP>& m,
725
        const S& s
726
727
    )
    {
728
729
730
        typedef typename EXP::type type;
        const type one = 1;
        return matrix_mul_scal_exp<EXP>(m.ref(),one/static_cast<type>(s));
731
732
733
    }

    template <
734
735
736
        typename EXP,
        bool B,
        typename S
737
        >
738
739
740
    inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
        const matrix_mul_scal_exp<EXP,B>& m,
        const S& s
741
742
    )
    {
743
744
        typedef typename EXP::type type;
        return matrix_mul_scal_exp<EXP>(m.m,m.s/static_cast<type>(s));
745
746
    }

747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
// ----------------------------------------------------------------------------------------

    // matrix_op operator for scalar/matrix: element (r,c) of the result is s / m(r,c).
    template <typename M>
    struct op_s_div_m : basic_op_m<M> 
    {
        typedef typename M::type type;
        typedef const typename M::type const_ret_type;
        const static long cost = M::cost+1;

        op_s_div_m( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}

        const_ret_type apply (long r, long c) const
        { 
            return s/this->m(r,c);
        }

        // The numerator shared by every element.
        const type s;
    };

    // Element-wise val/m for a non-matrix scalar val (converted to EXP::type).
    template <
        typename EXP,
        typename S
        >
    const typename disable_if<is_matrix<S>, matrix_op<op_s_div_m<EXP> > >::type operator/ (
        const S& val,
        const matrix_exp<EXP>& m
    )
    {
        typedef op_s_div_m<EXP> op_type;
        const typename EXP::type numerator = static_cast<typename EXP::type>(val);
        return matrix_op<op_type>(op_type(m.ref(), numerator));
    }

// ----------------------------------------------------------------------------------------

783
784
785
    template <
        typename EXP
        >
Davis King's avatar
Davis King committed
786
    inline const matrix_mul_scal_exp<EXP> operator- (
787
788
789
        const matrix_exp<EXP>& m
    )
    {
Davis King's avatar
Davis King committed
790
        return matrix_mul_scal_exp<EXP>(m.ref(),-1);
791
792
    }

793
794
795
796
797
798
799
800
801
802
803
    template <
        typename EXP,
        bool B
        >
    inline const matrix_mul_scal_exp<EXP> operator- (
        const matrix_mul_scal_exp<EXP,B>& m
    )
    {
        // -(m.m * m.s) is expressed by negating the existing scale factor rather
        // than wrapping the scaled expression in another scaling layer.
        return matrix_mul_scal_exp<EXP>(m.m, -1*m.s);
    }

// ----------------------------------------------------------------------------------------

    template <typename M>
    struct op_add_scalar : basic_op_m<M> 
    {
        /*!
            Element-wise operator for matrix_op: element (r,c) of the result is
            m(r,c) + s.
        !*/
        typedef typename M::type type;
        typedef const typename M::type const_ret_type;

        // one extra operation per element on top of evaluating m
        const static long cost = M::cost+1;

        op_add_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}

        // the scalar added to every element
        const type s;

        const_ret_type apply (long r, long c) const
        { 
            return this->m(r,c) + s;
        }
    };

    template <
        typename EXP,
        typename T
        >
    const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
        const matrix_exp<EXP>& m,
        const T& val
    )
    {
        // matrix + scalar: lazy expression whose (r,c) element is m(r,c) + val.
        typedef op_add_scalar<EXP> op_type;
        const typename EXP::type addend = static_cast<typename EXP::type>(val);
        return matrix_op<op_type>(op_type(m.ref(), addend));
    }

    template <
        typename EXP,
        typename T
        >
    const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
        const T& val,
        const matrix_exp<EXP>& m
    )
    {
        // scalar + matrix: reuses op_add_scalar, so each element is computed as
        // m(r,c) + val.
        typedef op_add_scalar<EXP> op_type;
        const typename EXP::type addend = static_cast<typename EXP::type>(val);
        return matrix_op<op_type>(op_type(m.ref(), addend));
    }

// ----------------------------------------------------------------------------------------

    template <typename M>
    struct op_subl_scalar : basic_op_m<M> 
    {
        /*!
            Element-wise operator for matrix_op with the scalar on the left:
            element (r,c) of the result is s - m(r,c).
        !*/
        typedef typename M::type type;
        typedef const typename M::type const_ret_type;

        // one extra operation per element on top of evaluating m
        const static long cost = M::cost+1;

        op_subl_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}

        // the scalar minuend
        const type s;

        const_ret_type apply (long r, long c) const
        { 
            return s - this->m(r,c);
        }
    };

    template <
        typename EXP,
        typename T
        >
    const typename disable_if<is_matrix<T>, matrix_op<op_subl_scalar<EXP> > >::type operator- (
        const T& val,
        const matrix_exp<EXP>& m
    )
    {
        // scalar - matrix: lazy expression whose (r,c) element is val - m(r,c).
        typedef op_subl_scalar<EXP> op_type;
        const typename EXP::type minuend = static_cast<typename EXP::type>(val);
        return matrix_op<op_type>(op_type(m.ref(), minuend));
    }

// ----------------------------------------------------------------------------------------

    template <typename M>
    struct op_subr_scalar : basic_op_m<M> 
    {
        /*!
            Element-wise operator for matrix_op with the scalar on the right:
            element (r,c) of the result is m(r,c) - s.
        !*/
        typedef typename M::type type;
        typedef const typename M::type const_ret_type;

        // one extra operation per element on top of evaluating m
        const static long cost = M::cost+1;

        op_subr_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}

        // the scalar subtrahend
        const type s;

        const_ret_type apply (long r, long c) const
        { 
            return this->m(r,c) - s;
        }
    };

    template <
        typename EXP,
        typename T
        >
    const typename disable_if<is_matrix<T>, matrix_op<op_subr_scalar<EXP> > >::type operator- (
        const matrix_exp<EXP>& m,
        const T& val
    )
    {
        // matrix - scalar: lazy expression whose (r,c) element is m(r,c) - val.
        typedef op_subr_scalar<EXP> op_type;
        const typename EXP::type subtrahend = static_cast<typename EXP::type>(val);
        return matrix_op<op_type>(op_type(m.ref(), subtrahend));
    }

// ----------------------------------------------------------------------------------------

    template <
        typename EXP1,
        typename EXP2
        >
    bool operator== (
        const matrix_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    )
    {
        // Matrices with different dimensions are never equal.
        if (m1.nr() != m2.nr() || m1.nc() != m2.nc())
            return false;

        // Same shape: compare element by element and stop at the first mismatch.
        for (long row = 0; row < m1.nr(); ++row)
        {
            for (long col = 0; col < m1.nc(); ++col)
            {
                if (m1(row,col) != m2(row,col))
                    return false;
            }
        }
        return true;
    }

    template <
        typename EXP1,
        typename EXP2
        >
    inline bool operator!= (
        const matrix_exp<EXP1>& m1,
        const matrix_exp<EXP2>& m2
    ) 
    { 
        // defined as the negation of operator== so the two can never disagree
        return !(m1 == m2); 
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename T,
        long num_rows,
        long num_cols,
964
965
        typename mem_manager,
        typename layout
966
        >
967
    struct matrix_traits<matrix<T,num_rows, num_cols, mem_manager, layout> >
968
969
    {
        typedef T type;
970
        typedef const T& const_ret_type;
971
        typedef mem_manager mem_manager_type;
972
        typedef layout layout_type;
973
974
        const static long NR = num_rows;
        const static long NC = num_cols;
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
        const static long cost = 1;

    };

    template <
        typename T,
        long num_rows,
        long num_cols,
        typename mem_manager,
        typename layout
        >
    class matrix : public matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> > 
    {
        /*!
            A concrete, storage-owning matrix.

            num_rows / num_cols fix a dimension at compile time when non-zero.  A
            value of 0 means that dimension is chosen at run time through the
            constructors or set_size().  Element storage is delegated to the layout
            policy held in the private `data` member.
        !*/

        COMPILE_TIME_ASSERT(num_rows >= 0 && num_cols >= 0); 

    public:
        typedef typename matrix_traits<matrix>::type type;
        typedef typename matrix_traits<matrix>::const_ret_type const_ret_type;
        typedef typename matrix_traits<matrix>::mem_manager_type mem_manager_type;
        typedef typename matrix_traits<matrix>::layout_type layout_type;
        const static long NR = matrix_traits<matrix>::NR;
        const static long NC = matrix_traits<matrix>::NC;
        const static long cost = matrix_traits<matrix>::cost;
        typedef T*          iterator;       
        typedef const T*    const_iterator; 

        matrix () 
        {
        }

        explicit matrix (
            long length 
        ) 
        /*!
            Constructs a row vector (NR == 1) or column vector of the given length.
            Only usable when this type is a row or column vector.
        !*/
        {
            // This object you are trying to call matrix(length) on is not a column or 
            // row vector.
            COMPILE_TIME_ASSERT(NR == 1 || NC == 1);
            DLIB_ASSERT( length >= 0, 
                "\tmatrix::matrix(length)"
                << "\n\tlength must be at least 0"
                << "\n\tlength: " << length 
                << "\n\tNR:     " << NR 
                << "\n\tNC:     " << NC 
                << "\n\tthis:   " << this
                );

            if (NR == 1)
            {
                DLIB_ASSERT(NC == 0 || NC == length,
                    "\tmatrix::matrix(length)"
                    << "\n\tSince this is a statically sized matrix length must equal NC"
                    << "\n\tlength: " << length 
                    << "\n\tNR:     " << NR 
                    << "\n\tNC:     " << NC 
                    << "\n\tthis:   " << this
                    );

                data.set_size(1,length);
            }
            else
            {
                // NOTE: this message previously named set_size(length) although it
                // is raised from the matrix(length) constructor.
                DLIB_ASSERT(NR == 0 || NR == length,
                    "\tmatrix::matrix(length)"
                    << "\n\tSince this is a statically sized matrix length must equal NR"
                    << "\n\tlength: " << length 
                    << "\n\tNR:     " << NR 
                    << "\n\tNC:     " << NC 
                    << "\n\tthis:   " << this
                    );

                data.set_size(length,1);
            }
        }

        matrix (
            long rows,
            long cols 
        )  
        /*!
            Constructs a rows x cols matrix.  Any compile-time dimension must match.
        !*/
        {
            DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) && 
                    rows >= 0 && cols >= 0, 
                "\tvoid matrix::matrix(rows, cols)"
                << "\n\tYou have supplied conflicting matrix dimensions"
                << "\n\trows: " << rows
                << "\n\tcols: " << cols
                << "\n\tNR:   " << NR 
                << "\n\tNC:   " << NC 
                );
            data.set_size(rows,cols);
        }

        template <typename EXP>
        matrix (
            const matrix_exp<EXP>& m
        )
        /*!
            Constructs this matrix by evaluating the expression m into it.
        !*/
        {
            // You get an error on this line if the matrix m contains a type that isn't
            // the same as the type contained in the target matrix.
            COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true) ||
                                (is_matrix<typename EXP::type>::value == true));

            // The matrix you are trying to assign m to is a statically sized matrix and 
            // m's dimensions don't match that of *this. 
            COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
            COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
            DLIB_ASSERT((NR == 0 || NR == m.nr()) && (NC == 0 || NC == m.nc()), 
                "\tmatrix& matrix::matrix(const matrix_exp& m)"
                << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size"
                << "\n\tNR:     " << NR
                << "\n\tNC:     " << NC
                << "\n\tm.nr(): " << m.nr()
                << "\n\tm.nc(): " << m.nc()
                << "\n\tthis:   " << this
                );

            data.set_size(m.nr(),m.nc());

            matrix_assign(*this, m);
        }

        matrix (
            const matrix& m
        ) : matrix_exp<matrix>(*this) 
        {
            data.set_size(m.nr(),m.nc());
            matrix_assign(*this, m);
        }

        template <typename U, size_t len>
        explicit matrix (
            U (&array)[len]
        ) 
        /*!
            Fills a statically sized matrix from a C array in row-major order.
            The array must have exactly NR*NC elements.
        !*/
        {
            COMPILE_TIME_ASSERT(NR*NC == len && len > 0);
            size_t idx = 0;
            for (long r = 0; r < NR; ++r)
            {
                for (long c = 0; c < NC; ++c)
                {
                    data(r,c) = static_cast<T>(array[idx]);
                    ++idx;
                }
            }
        }

        T& operator() (
            long r, 
            long c
        ) 
        { 
            DLIB_ASSERT(r < nr() && c < nc() &&
                   r >= 0 && c >= 0, 
                "\tT& matrix::operator(r,c)"
                << "\n\tYou must give a valid row and column"
                << "\n\tr:    " << r 
                << "\n\tc:    " << c
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc() 
                << "\n\tthis: " << this
                );
            return data(r,c); 
        }

        const T& operator() (
            long r, 
            long c
        ) const 
        { 
            DLIB_ASSERT(r < nr() && c < nc() &&
                   r >= 0 && c >= 0, 
                "\tconst T& matrix::operator(r,c)"
                << "\n\tYou must give a valid row and column"
                << "\n\tr:    " << r 
                << "\n\tc:    " << c
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc() 
                << "\n\tthis: " << this
                );
            return data(r,c);
        }

        T& operator() (
            long i
        ) 
        {
            // You can only use this operator on column or row vectors.
            COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
            DLIB_ASSERT(nc() == 1 || nr() == 1, 
                "\tconst type matrix::operator(i)"
                << "\n\tYou can only use this operator on column or row vectors"
                << "\n\ti:    " << i
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc()
                << "\n\tthis: " << this
                );
            DLIB_ASSERT( 0 <= i && i < size(), 
                "\tconst type matrix::operator(i)"
                << "\n\tYou must give a valid row/column number"
                << "\n\ti:      " << i
                << "\n\tsize(): " << size()
                << "\n\tthis:   " << this
                );
            return data(i);
        }

        const T& operator() (
            long i
        ) const
        {
            // You can only use this operator on column or row vectors.
            COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
            DLIB_ASSERT(nc() == 1 || nr() == 1, 
                "\tconst type matrix::operator(i)"
                << "\n\tYou can only use this operator on column or row vectors"
                << "\n\ti:    " << i
                << "\n\tnr(): " << nr()
                << "\n\tnc(): " << nc()
                << "\n\tthis: " << this
                );
            DLIB_ASSERT( 0 <= i && i < size(), 
                "\tconst type matrix::operator(i)"
                << "\n\tYou must give a valid row/column number"
                << "\n\ti:      " << i
                << "\n\tsize(): " << size()
                << "\n\tthis:   " << this
                );
            return data(i);
        }

        inline operator const type (
        ) const 
        /*!
            Implicit conversion of a 1x1 matrix to its single element.
        !*/
        {
            COMPILE_TIME_ASSERT(NC == 1 || NC == 0);
            COMPILE_TIME_ASSERT(NR == 1 || NR == 0);
            DLIB_ASSERT( nr() == 1 && nc() == 1 , 
                "\tmatrix::operator const type"
                << "\n\tYou can only attempt to implicit convert a matrix to a scalar if"
                << "\n\tthe matrix is a 1x1 matrix"
                << "\n\tnr(): " << nr() 
                << "\n\tnc(): " << nc() 
                << "\n\tthis: " << this
                );
            return data(0);
        }

        void set_size (
            long rows,
            long cols
        )
        /*!
            Resizes to rows x cols.  Storage is only touched if the size changes.
        !*/
        {
            DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) &&
                    rows >= 0 && cols >= 0, 
                "\tvoid matrix::set_size(rows, cols)"
                << "\n\tYou have supplied conflicting matrix dimensions"
                << "\n\trows: " << rows
                << "\n\tcols: " << cols
                << "\n\tNR:   " << NR 
                << "\n\tNC:   " << NC 
                << "\n\tthis: " << this
                );
            if (nr() != rows || nc() != cols)
                data.set_size(rows,cols);
        }

        void set_size (
            long length
        )
        /*!
            Resizes a row vector (NR == 1) or column vector to the given length.
        !*/
        {
            // This object you are trying to call set_size(length) on is not a column or 
            // row vector.
            COMPILE_TIME_ASSERT(NR == 1 || NC == 1);
            DLIB_ASSERT( length >= 0, 
                "\tvoid matrix::set_size(length)"
                << "\n\tlength must be at least 0"
                << "\n\tlength: " << length 
                << "\n\tNR:     " << NR 
                << "\n\tNC:     " << NC 
                << "\n\tthis:   " << this
                );

            if (NR == 1)
            {
                DLIB_ASSERT(NC == 0 || NC == length,
                    "\tvoid matrix::set_size(length)"
                    << "\n\tSince this is a statically sized matrix length must equal NC"
                    << "\n\tlength: " << length 
                    << "\n\tNR:     " << NR 
                    << "\n\tNC:     " << NC 
                    << "\n\tthis:   " << this
                    );

                if (nc() != length)
                    data.set_size(1,length);
            }
            else
            {
                DLIB_ASSERT(NR == 0 || NR == length,
                    "\tvoid matrix::set_size(length)"
                    << "\n\tSince this is a statically sized matrix length must equal NR"
                    << "\n\tlength: " << length 
                    << "\n\tNR:     " << NR 
                    << "\n\tNC:     " << NC 
                    << "\n\tthis:   " << this
                    );

                if (nr() != length)
                    data.set_size(length,1);
            }
        }

        long nr (
        ) const { return data.nr(); }

        long nc (
        ) const { return data.nc(); }

        long size (
        ) const { return data.nr()*data.nc(); }

        template <typename U, size_t len>
        matrix& operator= (
            U (&array)[len]
        )
        /*!
            Copies a C array of exactly NR*NC elements into this statically sized
            matrix in row-major order.
        !*/
        {
            COMPILE_TIME_ASSERT(NR*NC == len && len > 0);
            size_t idx = 0;
            for (long r = 0; r < NR; ++r)
            {
                for (long c = 0; c < NC; ++c)
                {
                    data(r,c) = static_cast<T>(array[idx]);
                    ++idx;
                }
            }
            return *this;
        }

        template <typename EXP>
        matrix& operator= (
            const matrix_exp<EXP>& m
        )
        /*!
            Evaluates m into *this.  If m destructively aliases *this, the result is
            built in a temporary and swapped in.
        !*/
        {
            // You get an error on this line if the matrix you are trying to 
            // assign m to is a statically sized matrix and m's dimensions don't 
            // match that of *this. 
            COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
            COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
            DLIB_ASSERT((NR == 0 || nr() == m.nr()) && 
                   (NC == 0 || nc() == m.nc()), 
                "\tmatrix& matrix::operator=(const matrix_exp& m)"
                << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size"
                << "\n\tnr():   " << nr()
                << "\n\tnc():   " << nc()
                << "\n\tm.nr(): " << m.nr()
                << "\n\tm.nc(): " << m.nc()
                << "\n\tthis:   " << this
                );

            // You get an error on this line if the matrix m contains a type that isn't
            // the same as the type contained in the target matrix.
            COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true) ||
                                (is_matrix<typename EXP::type>::value == true));
            if (m.destructively_aliases(*this) == false)
            {
                // This if statement is seemingly unnecessary since set_size() contains this
                // exact same if statement.  However, structuring the code this way causes
                // gcc to handle the way it inlines this function in a much more favorable way.
                if (data.nr() == m.nr() && data.nc() == m.nc())
                {
                    matrix_assign(*this, m);
                }
                else
                {
                    set_size(m.nr(),m.nc());
                    matrix_assign(*this, m);
                }
            }
            else
            {
                // we have to use a temporary matrix object here because
                // *this is aliased inside the matrix_exp m somewhere.
                matrix temp;
                temp.set_size(m.nr(),m.nc());
                matrix_assign(temp, m);
                temp.swap(*this);
            }
            return *this;
        }

        template <typename EXP>
        matrix& operator += (
            const matrix_exp<EXP>& m
        )
        /*!
            Adds m element-wise.  If the sizes differ, *this is replaced by m.
        !*/
        {
            // The matrix you are trying to assign m to is a statically sized matrix and 
            // m's dimensions don't match that of *this. 
            COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
            COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
            COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
            if (nr() == m.nr() && nc() == m.nc())
            {
                if (m.destructively_aliases(*this) == false)
                {
                    matrix_assign(*this, *this + m);
                }
                else
                {
                    // we have to use a temporary matrix object here because
                    // this->data is aliased inside the matrix_exp m somewhere.
                    matrix temp;
                    temp.set_size(m.nr(),m.nc());
                    matrix_assign(temp, *this + m);
                    temp.swap(*this);
                }
            }
            else
            {
                *this = m;
            }
            return *this;
        }


        template <typename EXP>
        matrix& operator -= (
            const matrix_exp<EXP>& m
        )
        /*!
            Subtracts m element-wise.  If the sizes differ, *this is replaced by -m.
        !*/
        {
            // The matrix you are trying to assign m to is a statically sized matrix and 
            // m's dimensions don't match that of *this. 
            COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
            COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
            COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
            if (nr() == m.nr() && nc() == m.nc())
            {
                if (m.destructively_aliases(*this) == false)
                {
                    matrix_assign(*this, *this - m);
                }
                else
                {
                    // we have to use a temporary matrix object here because
                    // this->data is aliased inside the matrix_exp m somewhere.
                    matrix temp;
                    temp.set_size(m.nr(),m.nc());
                    matrix_assign(temp, *this - m);
                    temp.swap(*this);
                }
            }
            else
            {
                *this = -m;
            }
            return *this;
        }

        template <typename EXP>
        matrix& operator *= (
            const matrix_exp<EXP>& m
        )
        {
            // aliasing between *this and the product is handled by operator=
            *this = *this * m;
            return *this;
        }

        matrix& operator += (
            const matrix& m
        )
        /*!
            Element-wise add of another matrix of the same type.  If the sizes
            differ, *this becomes a copy of m.
        !*/
        {
            const long size = m.nr()*m.nc();
            if (nr() == m.nr() && nc() == m.nc())
            {
                for (long i = 0; i < size; ++i)
                    data(i) += m.data(i);
            }
            else
            {
                set_size(m.nr(), m.nc());
                for (long i = 0; i < size; ++i)
                    data(i) = m.data(i);
            }
            return *this;
        }

        matrix& operator -= (
            const matrix& m
        )
        /*!
            Element-wise subtract of another matrix of the same type.  If the sizes
            differ, *this becomes the negation of m.
        !*/
        {
            const long size = m.nr()*m.nc();
            if (nr() == m.nr() && nc() == m.nc())
            {
                for (long i = 0; i < size; ++i)
                    data(i) -= m.data(i);
            }
            else
            {
                set_size(m.nr(), m.nc());
                for (long i = 0; i < size; ++i)
                    data(i) = -m.data(i);
            }
            return *this;
        }

        matrix& operator += (
            const T val
        )
        {
            // adds val to every element
            const long size = nr()*nc();
            for (long i = 0; i < size; ++i)
                data(i) += val;

            return *this;
        }

        matrix& operator -= (
            const T val
        )
        {
            // subtracts val from every element
            const long size = nr()*nc();
            for (long i = 0; i < size; ++i)
                data(i) -= val;

            return *this;
        }

        matrix& operator *= (
            const T a
        )
        {
            *this = *this * a;
            return *this;
        }

        matrix& operator /= (
            const T a
        )
        {
            *this = *this / a;
            return *this;
        }

        matrix& operator= (
            const matrix& m
        )
        {
            // self-assignment is a no-op
            if (this != &m)
            {
                set_size(m.nr(),m.nc());
                const long size = m.nr()*m.nc();
                for (long i = 0; i < size; ++i)
                    data(i) = m.data(i);
            }
            return *this;
        }

        void swap (
            matrix& item
        )
        {
            // exchanges the underlying layout storage
            data.swap(item.data);
        }

        // A matrix never aliases an expression of a different type ...
        template <typename U>
        bool aliases (
            const matrix_exp<U>& 
        ) const { return false; }

        // ... and aliases an expression of its own type only if it is that object.
        bool aliases (
            const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
        ) const { return (this == &item); }

        template <typename U>
        bool destructively_aliases (
            const matrix_exp<U>& 
        ) const { return false; }

        // Pointer iterators over the element storage; null when the matrix is empty.
        iterator begin() 
        {
            if (size() != 0)
                return &data(0,0);
            else
                return 0;
        }

        iterator end()
        {
            if (size() != 0)
                return &data(0,0)+size();
            else
                return 0;
        }

        const_iterator begin()  const
        {
            if (size() != 0)
                return &data(0,0);
            else
                return 0;
        }

        const_iterator end() const
        {
            if (size() != 0)
                return &data(0,0)+size();
            else
                return 0;
        }

    private:
        struct literal_assign_helper
        {
            /*
                This struct is a helper struct returned by the operator=(const T&) function
                below.  It implements the comma based form of matrix assignment
                (e.g. m = 1, 2, 3, 4;) and exists primarily to enable us to put
                DLIB_CASSERT statements on the usage of that form.
            */

            // A copy starts out "unused" so only the helper that actually received
            // comma-assigned values checks for completeness in its destructor.
            literal_assign_helper(const literal_assign_helper& item) : m(item.m), r(item.r), c(item.c), has_been_used(false) {}
            // Element (0,0) is already filled by operator=, so start at the next slot.
            literal_assign_helper(matrix* m_): m(m_), r(0), c(0),has_been_used(false) {next();}
            ~literal_assign_helper()
            {
                DLIB_CASSERT(!has_been_used || r == m->nr(),
                             "You have used the matrix comma based assignment incorrectly by failing to\n"
                             "supply a full set of values for every element of a matrix object.\n");
            }

            const literal_assign_helper& operator, (
                const T& val
            ) const
            {
                DLIB_CASSERT(r < m->nr() && c < m->nc(),
                             "You have used the matrix comma based assignment incorrectly by attempting to\n" <<
                             "supply more values than there are elements in the matrix object being assigned to.\n\n" <<
                             "Did you forget to call set_size()?" 
                             << "\n\t r: " << r 
                             << "\n\t c: " << c 
                             << "\n\t m->nr(): " << m->nr()
                             << "\n\t m->nc(): " << m->nc());
                (*m)(r,c) = val;
                next();
                has_been_used = true;
                return *this;
            }

        private:

            // advance (r,c) one position in row-major order
            void next (
            ) const
            {
                ++c;
                if (c == m->nc())
                {
                    c = 0;
                    ++r;
                }
            }

            matrix* m;
            mutable long r;
            mutable long c;
            mutable bool has_been_used;
        };

    public:

        const literal_assign_helper operator = (
            const T& val
        ) 
        /*!
            Sets every element to val and returns a helper that allows the rest of
            the elements to be supplied with the comma operator.
        !*/
        {  
            // assign the given value to every spot in this matrix
            for (long r = 0; r < nr(); ++r)
            {
                for (long c = 0; c < nc(); ++c)
                {
                    data(r,c) = val;
                }
            }

            // Now return the literal_assign_helper so that the user
            // can use the overloaded comma notation to initialize 
            // the matrix if they want to.
            return literal_assign_helper(this); 
        }

    private:

        typename layout::template layout<T,NR,NC,mem_manager> data;
    };

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename T,
        long NR,
        long NC,
1675
1676
        typename mm,
        typename l
1677
1678
        >
    void swap(
1679
1680
        matrix<T,NR,NC,mm,l>& a,
        matrix<T,NR,NC,mm,l>& b
1681
1682
1683
1684
1685
1686
    ) { a.swap(b); }

    template <
        typename T,
        long NR,
        long NC,
1687
1688
        typename mm,
        typename l
1689
1690
        >
    void serialize (
1691
        const matrix<T,NR,NC,mm,l>& item, 
1692
1693
1694
1695
1696
        std::ostream& out
    )
    {
        try
        {
1697
1698
1699
1700
1701
1702
            // The reason the serialization is a little funny is because we are trying to
            // maintain backwards compatibility with an older serialization format used by
            // dlib while also encoding things in a way that lets the array2d and matrix
            // objects have compatible serialization formats.
            serialize(-item.nr(),out);
            serialize(-item.nc(),out);
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
            for (long r = 0; r < item.nr(); ++r)
            {
                for (long c = 0; c < item.nc(); ++c)
                {
                    serialize(item(r,c),out);
                }
            }
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while serializing dlib::matrix");
        }
    }

    template <
        typename T,
        long NR,
        long NC,
1721
1722
        typename mm,
        typename l
1723
1724
        >
    void deserialize (
1725
        matrix<T,NR,NC,mm,l>& item, 
1726
1727
1728
1729
1730
1731
1732
1733
1734
        std::istream& in
    )
    {
        try
        {
            long nr, nc;
            deserialize(nr,in); 
            deserialize(nc,in); 

1735
1736
1737
1738
1739
1740
1741
            // this is the newer serialization format
            if (nr < 0 || nc < 0)
            {
                nr *= -1;
                nc *= -1;
            }

1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
            if (NR != 0 && nr != NR)
                throw serialization_error("Error while deserializing a dlib::matrix.  Invalid rows");
            if (NC != 0 && nc != NC)
                throw serialization_error("Error while deserializing a dlib::matrix.  Invalid columns");

            item.set_size(nr,nc);
            for (long r = 0; r < nr; ++r)
            {
                for (long c = 0; c < nc; ++c)
                {
                    deserialize(item(r,c),in);
                }
            }
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while deserializing a dlib::matrix");
        }
    }

    // Prints m to out one row per line, right-aligning every element in a column
    // of uniform width.  The width is the widest textual representation of any
    // element.  out's width setting is restored before returning.
    template <
        typename EXP
        >
    std::ostream& operator<< (
        std::ostream& out,
        const matrix_exp<EXP>& m
    )
    {
        const std::streamsize old = out.width();

        // First figure out how wide we should make each field.  The scratch
        // stream gets out's formatting state (flags such as fixed/hex/showpos,
        // precision and locale).  Otherwise the measured widths would not match
        // the text actually written to out.
        std::string::size_type w = 0;
        std::ostringstream sout;
        sout.flags(out.flags());
        sout.precision(out.precision());
        sout.imbue(out.getloc());
        for (long r = 0; r < m.nr(); ++r)
        {
            for (long c = 0; c < m.nc(); ++c)
            {
                sout << m(r,c); 
                w = std::max(sout.str().size(),w);
                sout.str("");
            }
        }

        // now actually print it
        for (long r = 0; r < m.nr(); ++r)
        {
            for (long c = 0; c < m.nc(); ++c)
            {
                out.width(static_cast<std::streamsize>(w));
                out << m(r,c) << " ";
            }
            out << "\n";
        }
        out.width(old);
        return out;
    }

1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
    /*
    template <
        typename T, 
        long NR, 
        long NC,
        typename MM,
        typename L
        >
    std::istream& operator>> (
        std::istream& in,
        matrix<T,NR,NC,MM,L>& m
    );

    This function is defined inside the matrix_read_from_istream.h file.
    */

1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <typename EXP>
    class const_temp_matrix;

    // matrix_traits specialization for const_temp_matrix<EXP>.  The element
    // type, return type, memory manager, layout and compile-time dimensions are
    // all forwarded from the wrapped expression EXP.  The cost is fixed at 1.
    // NOTE(review): presumably 1 because expensive expressions are materialized
    // into a temporary by const_temp_matrix (see its conditional_matrix_temp
    // member); confirm.
    template <
        typename EXP
        >
    struct matrix_traits<const_temp_matrix<EXP> >
    {
        typedef typename EXP::type type;
        typedef typename EXP::const_ret_type const_ret_type;
        typedef typename EXP::mem_manager_type mem_manager_type;
        typedef typename EXP::layout_type layout_type;
        const static long NR = EXP::NR;
        const static long NC = EXP::NC;
        const static long cost = 1;
    };

    // Read-only matrix expression that wraps another expression EXP.  The
    // wrapped value is stored in ref_, whose type is chosen by
    // conditional_matrix_temp based on whether EXP::cost <= 1.
    // NOTE(review): presumably cheap expressions are held by reference and
    // costlier ones are evaluated into a temporary; confirm against
    // conditional_matrix_temp.
    // Element access, aliasing queries and dimensions are all forwarded to ref_.
    // The class is noncopyable.
    template <typename EXP>
    class const_temp_matrix : public matrix_exp<const_temp_matrix<EXP> >, noncopyable 
    {
    public:
        typedef typename matrix_traits<const_temp_matrix>::type type;
        typedef typename matrix_traits<const_temp_matrix>::const_ret_type const_ret_type;
        typedef typename matrix_traits<const_temp_matrix>::mem_manager_type mem_manager_type;
        typedef typename matrix_traits<const_temp_matrix>::layout_type layout_type;
        const static long NR = matrix_traits<const_temp_matrix>::NR;
        const static long NC = matrix_traits<const_temp_matrix>::NC;
        const static long cost = matrix_traits<const_temp_matrix>::cost;

        // Wrap the concrete expression behind a matrix_exp base reference.
        const_temp_matrix (
            const matrix_exp<EXP>& item
        ) :
            ref_(item.ref())
        {}
        // Wrap a concrete expression object directly.
        const_temp_matrix (
            const EXP& item
        ) :
            ref_(item)
        {}

        // Element (r,c) of the wrapped expression.
        const_ret_type operator() (
            long r, 
            long c
        ) const { return ref_(r,c); }

        // Single-index element access, forwarded to the wrapped expression.
        const_ret_type operator() ( long i ) const 
        { return ref_(i); }

        // Forwarded aliasing query on the wrapped expression.
        template <typename U>
        bool aliases (
            const matrix_exp<U>& item
        ) const { return ref_.aliases(item); }

        // Forwarded destructive-aliasing query on the wrapped expression.
        template <typename U>
        bool destructively_aliases (
            const matrix_exp<U>& item
        ) const { return ref_.destructively_aliases(item); }

        long nr (
        ) const { return ref_.nr(); }

        long nc (
        ) const { return ref_.nc(); }

    private:

        typename conditional_matrix_temp<const EXP, (EXP::cost <= 1)>::type ref_;
    };

1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
// ----------------------------------------------------------------------------------------

}

#ifdef _MSC_VER
// put that warning back to its default setting
#pragma warning(default : 4355)
#endif

#endif // DLIB_MATRIx_