Commit af454aeb authored by Paul's avatar Paul
Browse files

Formatting

parent 42e67e3d
...@@ -7,40 +7,37 @@ ...@@ -7,40 +7,37 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs) void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm(std::forward<Ts>(xs)...); rocblas_sgemm(std::forward<Ts>(xs)...);
} }
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs) void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm(std::forward<Ts>(xs)...); rocblas_dgemm(std::forward<Ts>(xs)...);
} }
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs) void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{ {
rocblas_hgemm(std::forward<Ts>(xs)...); rocblas_hgemm(std::forward<Ts>(xs)...);
} }
template<class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...) void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPH_THROW("Type unsupported by rocblas"); MIGRAPH_THROW("Type unsupported by rocblas");
} }
template<class T> template <class T>
T to_rocblas_type(T x) T to_rocblas_type(T x)
{ {
return x; return x;
} }
rocblas_half to_rocblas_type(half x) rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
{
return reinterpret_cast<const rocblas_half&>(x);
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
...@@ -64,7 +61,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -64,7 +61,8 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha)); auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(), generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment