Commit af454aeb authored by Paul's avatar Paul
Browse files

Formatting

parent 42e67e3d
......@@ -7,40 +7,37 @@
namespace migraph {
namespace gpu {
template<class... Ts>
template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm(std::forward<Ts>(xs)...);
rocblas_sgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm(std::forward<Ts>(xs)...);
rocblas_dgemm(std::forward<Ts>(xs)...);
}
template<class... Ts>
template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm(std::forward<Ts>(xs)...);
rocblas_hgemm(std::forward<Ts>(xs)...);
}
template<class T, class... Ts>
template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPH_THROW("Type unsupported by rocblas");
MIGRAPH_THROW("Type unsupported by rocblas");
}
template<class T>
template <class T>
T to_rocblas_type(T x)
{
return x;
return x;
}
rocblas_half to_rocblas_type(half x)
{
return reinterpret_cast<const rocblas_half&>(x);
}
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
......@@ -62,23 +59,24 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta_r,
args[2].implicit(),
ldc);
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta_r,
args[2].implicit(),
ldc);
});
return args[2];
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment