"git@developer.sourcefind.cn:change/sglang.git" did not exist on "977f785dad98540f01bca34abe6c6fd326fd6a7c"
Commit 195c7694 authored by Davis King's avatar Davis King
Browse files

Made column major matrices directly wrap matlab matrix objects when used

inside mex files.  This way, if you use matrix_colmajor or fmatrix_colmajor
in a mex file it will not do any unnecessary copying or transposing.
parent 6178838e
......@@ -261,6 +261,20 @@ namespace mex_binding
std::string msg;
};
// -------------------------------------------------------
template <typename T>
struct is_column_major_matrix : public default_is_kind_value {};
template <
typename T,
long num_rows,
long num_cols,
typename mem_manager
>
struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> >
{ static const bool value = true; };
// -------------------------------------------------------
template <
......@@ -563,7 +577,10 @@ namespace mex_binding
sout << " argument " << arg_idx+1 << " must be a matrix of doubles";
throw invalid_args_exception(sout.str());
}
assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr));
if (is_column_major_matrix<T>::value)
arg._private_set_mxArray((mxArray*)prhs);
else
assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr));
}
else if (is_same_type<type, float>::value)
{
......@@ -574,7 +591,10 @@ namespace mex_binding
throw invalid_args_exception(sout.str());
}
assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr));
if (is_column_major_matrix<T>::value)
arg._private_set_mxArray((mxArray*)prhs);
else
assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr));
}
else if (is_same_type<type, bool>::value)
{
......@@ -922,6 +942,26 @@ namespace mex_binding
}
}
void assign_to_matlab(
mxArray*& plhs,
matrix_colmajor& item
)
{
// Don't need to do a copy if it's this kind of matrix since we can just
// pull the underlying mxArray out directly and thus avoid a copy.
plhs = item._private_release_mxArray();
}
void assign_to_matlab(
mxArray*& plhs,
fmatrix_colmajor& item
)
{
// Don't need to do a copy if it's this kind of matrix since we can just
// pull the underlying mxArray out directly and thus avoid a copy.
plhs = item._private_release_mxArray();
}
void assign_to_matlab(
mxArray*& plhs,
matlab_struct& item
......@@ -989,6 +1029,14 @@ namespace mex_binding
{
}
// ----------------------------------------------------------------------------------------
template <typename T>
void mark_non_persistent (const T&){}
void mark_non_persistent(matrix_colmajor& item) { item._private_mark_non_persistent(); }
void mark_non_persistent(fmatrix_colmajor& item) { item._private_mark_non_persistent(); }
// ----------------------------------------------------------------------------------------
template <
......@@ -1010,6 +1058,8 @@ namespace mex_binding
typename basic_type<arg1_type>::type A1;
mark_non_persistent(A1);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
......@@ -1036,6 +1086,9 @@ namespace mex_binding
typename basic_type<arg1_type>::type A1;
typename basic_type<arg2_type>::type A2;
mark_non_persistent(A1);
mark_non_persistent(A2);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1066,6 +1119,10 @@ namespace mex_binding
typename basic_type<arg2_type>::type A2;
typename basic_type<arg3_type>::type A3;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1100,6 +1157,11 @@ namespace mex_binding
typename basic_type<arg3_type>::type A3;
typename basic_type<arg4_type>::type A4;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1139,6 +1201,12 @@ namespace mex_binding
typename basic_type<arg4_type>::type A4;
typename basic_type<arg5_type>::type A5;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1183,6 +1251,13 @@ namespace mex_binding
typename basic_type<arg5_type>::type A5;
typename basic_type<arg6_type>::type A6;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1231,6 +1306,14 @@ namespace mex_binding
typename basic_type<arg6_type>::type A6;
typename basic_type<arg7_type>::type A7;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1283,6 +1366,15 @@ namespace mex_binding
typename basic_type<arg7_type>::type A7;
typename basic_type<arg8_type>::type A8;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1339,6 +1431,16 @@ namespace mex_binding
typename basic_type<arg8_type>::type A8;
typename basic_type<arg9_type>::type A9;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
mark_non_persistent(A9);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1400,6 +1502,17 @@ namespace mex_binding
typename basic_type<arg9_type>::type A9;
typename basic_type<arg10_type>::type A10;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
mark_non_persistent(A9);
mark_non_persistent(A10);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......
......@@ -17,6 +17,10 @@
#include "matrix_op.h"
#include <utility>
#ifdef MATLAB_MEX_FILE
#include <mex.h>
#endif
#ifdef _MSC_VER
// Disable the following warnings for Visual Studio
......@@ -1239,6 +1243,26 @@ namespace dlib
return data(0);
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray(
mxArray* mem
)
{
data._private_set_mxArray(mem);
}
mxArray* _private_release_mxArray(
)
{
return data._private_release_mxArray();
}
void _private_mark_non_persistent()
{
data._private_mark_non_persistent();
}
#endif
void set_size (
long rows,
long cols
......@@ -1971,6 +1995,9 @@ namespace dlib
// ----------------------------------------------------------------------------------------
typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
}
#ifdef _MSC_VER
......
......@@ -645,7 +645,18 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
/*!A matrix_colmajor
This is just a typedef of the matrix object that uses column major layout.
!*/
typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
/*!A fmatrix_colmajor
This is just a typedef of the matrix object that uses column major layout.
!*/
typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
// ----------------------------------------------------------------------------------------
template <
typename T,
long NR,
long NC,
......
......@@ -6,6 +6,9 @@
#include "../algs.h"
#include "matrix_fwd.h"
#include "matrix_data_layout_abstract.h"
#ifdef MATLAB_MEX_FILE
#include <mex.h>
#endif
// GCC 4.8 gives false alarms about some matrix operations going out of bounds. Disable
// these false warnings.
......@@ -180,6 +183,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T data[num_rows*num_cols];
};
......@@ -243,6 +252,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -318,6 +333,12 @@ namespace dlib
nr_ = nr;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -396,6 +417,12 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -476,6 +503,11 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
long nr_;
......@@ -593,6 +625,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T data[num_cols*num_rows];
};
......@@ -656,6 +694,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -731,6 +775,12 @@ namespace dlib
nr_ = nr;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -809,6 +859,12 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -869,6 +925,12 @@ namespace dlib
pool.swap(item.pool);
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
long nr (
) const { return nr_; }
......@@ -894,7 +956,251 @@ namespace dlib
long nr_;
long nc_;
typename mem_manager::template rebind<T>::other pool;
};
};
#ifdef MATLAB_MEX_FILE
template <
long num_rows,
long num_cols
>
class layout<double,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
{
public:
const static long NR = num_rows;
const static long NC = num_cols;
layout (
): data(0), nr_(0), nc_(0), make_persistent(true),set_by_private_set_mxArray(false),mem(0) { }
~layout ()
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
}
double& operator() (
long r,
long c
) { return data[c*nr_ + r]; }
const double& operator() (
long r,
long c
) const { return data[c*nr_ + r]; }
double& operator() (
long i
) { return data[i]; }
const double& operator() (
long i
) const { return data[i]; }
void _private_set_mxArray (
mxArray* mem_
)
{
// We don't own the pointer, so make note of that so we won't try to free
// it.
set_by_private_set_mxArray = true;
mem = mem_;
data = mxGetPr(mem);
nr_ = mxGetM(mem);
nc_ = mxGetN(mem);
}
mxArray* _private_release_mxArray()
{
DLIB_CASSERT(!make_persistent,"");
mxArray* temp = mem;
mem = 0;
set_by_private_set_mxArray = false;
data = 0;
nr_ = 0;
nc_ = 0;
return temp;
}
void _private_mark_non_persistent()
{
make_persistent = false;
}
void swap(
layout& item
)
{
std::swap(item.make_persistent,make_persistent);
std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
std::swap(item.mem,mem);
std::swap(item.data,data);
std::swap(item.nc_,nc_);
std::swap(item.nr_,nr_);
}
long nr (
) const { return nr_; }
long nc (
) const { return nc_; }
void set_size (
long nr,
long nc
)
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
set_by_private_set_mxArray = false;
mem = mxCreateDoubleMatrix(nr, nc, mxREAL);
if (mem == 0)
throw std::bad_alloc();
if (make_persistent)
mexMakeArrayPersistent(mem);
data = mxGetPr(mem);
nr_ = nr;
nc_ = nc;
}
private:
double* data;
long nr_;
long nc_;
bool make_persistent;
bool set_by_private_set_mxArray;
mxArray* mem;
};
template <
long num_rows,
long num_cols
>
class layout<float,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
{
public:
const static long NR = num_rows;
const static long NC = num_cols;
layout (
): data(0), nr_(0), nc_(0), make_persistent(true),set_by_private_set_mxArray(false),mem(0) { }
~layout ()
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
}
float& operator() (
long r,
long c
) { return data[c*nr_ + r]; }
const float& operator() (
long r,
long c
) const { return data[c*nr_ + r]; }
float& operator() (
long i
) { return data[i]; }
const float& operator() (
long i
) const { return data[i]; }
void _private_set_mxArray (
mxArray* mem_
)
{
// We don't own the pointer, so make note of that so we won't try to free
// it.
set_by_private_set_mxArray = true;
mem = mem_;
data = (float*)mxGetData(mem);
nr_ = mxGetM(mem);
nc_ = mxGetN(mem);
}
mxArray* _private_release_mxArray()
{
DLIB_CASSERT(!make_persistent,"");
mxArray* temp = mem;
mem = 0;
set_by_private_set_mxArray = false;
data = 0;
nr_ = 0;
nc_ = 0;
return temp;
}
void _private_mark_non_persistent()
{
make_persistent = false;
}
void swap(
layout& item
)
{
std::swap(item.make_persistent,make_persistent);
std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
std::swap(item.mem,mem);
std::swap(item.data,data);
std::swap(item.nc_,nc_);
std::swap(item.nr_,nr_);
}
long nr (
) const { return nr_; }
long nc (
) const { return nc_; }
void set_size (
long nr,
long nc
)
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
set_by_private_set_mxArray = false;
mem = mxCreateNumericMatrix(nr, nc, mxSINGLE_CLASS, mxREAL);
if (mem == 0)
throw std::bad_alloc();
if (make_persistent)
mexMakeArrayPersistent(mem);
data = (float*)mxGetData(mem);
nr_ = nr;
nc_ = nc;
}
private:
float* data;
long nr_;
long nc_;
bool make_persistent;
bool set_by_private_set_mxArray;
mxArray* mem;
};
#endif
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment