Commit 201838d7 authored by Paul's avatar Paul
Browse files

Fix rocblas errors

parent cf86db72
......@@ -32,10 +32,37 @@ void generic_rocblas_gemm(shape::as<T>, Ts&&...)
MIGRAPH_THROW("Type unsupported by rocblas");
}
template<class T>
struct compute_rocblas_type
{
using type = T;
};
template<class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
template<>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
template<class T>
using rb_type = typename compute_rocblas_type<T>::type;
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
template <class T>
T to_rocblas_type(T x)
rb_type<T>* to_rocblas_type(T* x)
{
return x;
return reinterpret_cast<rb_type<T>*>(x);
}
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
......@@ -62,6 +89,9 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) {
return to_rocblas_type(as.from(arg.data()));
};
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
......@@ -70,12 +100,12 @@ argument miopen_gemm::compute(context& ctx,
m,
k,
&alpha_r,
args[1].implicit(),
to_pointer(args[1]),
ldb,
args[0].implicit(),
to_pointer(args[0]),
lda,
&beta_r,
args[2].implicit(),
to_pointer(args[2]),
ldc);
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment