Commit 20b1d690 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
......@@ -44,13 +44,9 @@ struct is_fast_gemm_type<float> : std::true_type
{
};
template <class T>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
std::true_type)
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
......@@ -66,13 +62,9 @@ void migemm_impl(tensor_view<T> cmat,
});
}
template <class T>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
std::false_type)
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
......@@ -95,9 +87,8 @@ void migemm_impl(tensor_view<T> cmat,
});
}
template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
......@@ -113,13 +104,29 @@ void migemm_impl(
}
}
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
template <class F>
void migemm_tpl(
const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
}
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment