Unverified Commit 092fa303 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add softmax function for matrix type (#2320)

* Add softmax function for matrix type

* make softmax inherit from basic_op_m

* fix comment

* add test for matrix softmax

* remove include

* take inspiration from op_normalize

* use multiplication instead of division

* fix typo in documentation
parent 3162f93c
......@@ -442,6 +442,40 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_softmax : basic_op_m<M>
{
    /*!
        Element-wise operation used by soft_max().  For the wrapped
        expression m it evaluates
            std::exp(m(r,c) - v) * s
        where v is a shift and s is a scale, both supplied by the caller
        at construction time.
    !*/

    typedef typename M::type type;
    typedef type const_ret_type;

    // Cost of the wrapped expression plus a fixed amount for the work done in
    // apply().  NOTE(review): the constant 9 is not derived anywhere in this
    // block -- presumably a rough estimate for the subtract/exp/multiply.
    const static long cost = M::cost + 9;

    const type s;   // multiplicative scale applied after the exponential
    const type v;   // value subtracted from each element before the exponential

    op_softmax(const M& mat, const type& scale, const type& shift)
        : basic_op_m<M>(mat), s(scale), v(shift)
    {}

    // Returns the shifted, exponentiated and scaled element at (r,c).
    const_ret_type apply(long r, long c) const
    {
        return std::exp(this->m(r, c) - v) * s;
    }
};
/*!
    Builds the lazy softmax expression of m:
        R(r,c) == exp(m(r,c)) / sum(exp(m))
    The maximum element of m is subtracted from every element before the
    exponential is taken.  That shift cancels out in the ratio, so the result
    is mathematically unchanged, while every exponent becomes <= 0.
!*/
template <typename EXP>
const matrix_op<op_softmax<EXP> > soft_max (
    const matrix_exp<EXP>& m
)
{
    typedef typename EXP::type value_type;
    typedef op_softmax<EXP> op;

    // softmax is only provided for matrices of float, double or long double.
    COMPILE_TIME_ASSERT((
        is_same_type<value_type,float>::value == true ||
        is_same_type<value_type,double>::value == true ||
        is_same_type<value_type,long double>::value == true
    ));

    // Shift used inside every exponential.
    value_type shift = max(m);

    // Reciprocal of the normaliser, computed once so that each element only
    // needs a multiplication rather than a division.
    value_type inv_total = static_cast<value_type>(1) / sum(exp(m - shift));

    return matrix_op<op>(op(m.ref(), inv_total, shift));
}
}
#endif // DLIB_MATRIx_MATH_FUNCTIONS
......
......@@ -587,6 +587,22 @@ namespace dlib
R(r,c) == std::tanh(m(r,c))
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp soft_max (
const matrix_exp& m
);
/*!
requires
- matrix_exp::type == float, double, or long double
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R has the same dimensions as m
- for all valid r and c:
R(r,c) == std::exp(m(r,c)) / sum(exp(m))
- The normalisation runs over every element of m (not per row or
per column), so sum(R) == 1 up to rounding error.
!*/
// ----------------------------------------------------------------------------------------
}
......
......@@ -1518,6 +1518,23 @@ namespace
DLIB_TEST(equal(rowm(a,0) , trans(m*b)));
DLIB_TEST(!equal(rowm(a,0) , m*b));
}
{
    // soft_max() must agree with a straightforward reference computation:
    // shift by the maximum, exponentiate, then scale by 1/sum.
    matrix<double> x, y;
    x = 10 * gaussian_randm(100, 1) - 10;
    y = soft_max(x);

    // Reference maximum of x.
    double max_val = -std::numeric_limits<double>::infinity();
    for (const auto i : x)
        max_val = std::max(max_val, i);

    // Reference normaliser.
    double sum_exps = 0;
    for (const auto i : x)
        sum_exps += std::exp(i - max_val);
    const double scale = 1.0 / sum_exps;

    // The output keeps the shape of the input.
    DLIB_TEST(y.nr() == x.nr() && y.nc() == x.nc());

    // Compare with a relative tolerance instead of exact equality: the result
    // of a floating point computation may differ in the last bits depending on
    // how the compiler evaluates it (e.g. fused multiply-add or extended
    // precision registers).  DLIB_TEST is used, like the other checks in this
    // test, so that failures are reported through the test framework.
    const double rel_tol = 1e-12;
    for (long i = 0; i < x.nr(); ++i)
    {
        const double expected = std::exp(x(i) - max_val) * scale;
        DLIB_TEST(std::abs(y(i) - expected) <= rel_tol * expected);
    }

    // A softmax output sums to one.
    DLIB_TEST(std::abs(sum(y) - 1) < 1e-12);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment