".github/vscode:/vscode.git/clone" did not exist on "7ffc5b4418ad297fe05ddfe8007db38b3eb54d8b"
Commit af454aeb authored by Paul's avatar Paul
Browse files

Formatting

parent 42e67e3d
...@@ -7,40 +7,37 @@ ...@@ -7,40 +7,37 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs) void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{ {
rocblas_sgemm(std::forward<Ts>(xs)...); rocblas_sgemm(std::forward<Ts>(xs)...);
} }
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs) void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{ {
rocblas_dgemm(std::forward<Ts>(xs)...); rocblas_dgemm(std::forward<Ts>(xs)...);
} }
template<class... Ts> template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs) void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{ {
rocblas_hgemm(std::forward<Ts>(xs)...); rocblas_hgemm(std::forward<Ts>(xs)...);
} }
template<class T, class... Ts> template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...) void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{ {
MIGRAPH_THROW("Type unsupported by rocblas"); MIGRAPH_THROW("Type unsupported by rocblas");
} }
template<class T> template <class T>
T to_rocblas_type(T x) T to_rocblas_type(T x)
{ {
return x; return x;
} }
rocblas_half to_rocblas_type(half x) rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
{
return reinterpret_cast<const rocblas_half&>(x);
}
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
...@@ -62,23 +59,24 @@ argument miopen_gemm::compute(context& ctx, ...@@ -62,23 +59,24 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int n = output_shape.lens()[1]; rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha)); auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(), generic_rocblas_gemm(as,
transb ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
transa ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
n, transa ? rocblas_operation_transpose : rocblas_operation_none,
m, n,
k, m,
&alpha_r, k,
args[1].implicit(), &alpha_r,
ldb, args[1].implicit(),
args[0].implicit(), ldb,
lda, args[0].implicit(),
&beta_r, lda,
args[2].implicit(), &beta_r,
ldc); args[2].implicit(),
ldc);
}); });
return args[2]; return args[2];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment