Commit 7d986afb authored by Shucai Xiao

Code backup of the extended GEMM implementation.

parent 0ea7b7a3
...@@ -5,6 +5,84 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
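// The overloads below tag-dispatch on shape::as<T>: the shape visitor
// supplies a typed tag, overload resolution selects the matching rocBLAS
// entry point, and the variadic catch-all rejects unsupported types.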
template <class... Ts>
void generic_rocblas_scal(shape::as<float>, Ts&&... xs)
{
rocblas_sscal(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_scal(shape::as<double>, Ts&&... xs)
{
rocblas_dscal(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_scal(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
{
rocblas_haxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
{
rocblas_saxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
{
rocblas_daxpy(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
rocblas_sdot(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{
rocblas_ddot(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
rocblas_sgemv(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
rocblas_dgemv(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
}
template <class... Ts>
void generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
...@@ -92,10 +170,90 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    return op.compute_shape(inputs);
}
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>&,
                                        std::size_t,
                                        std::vector<std::size_t>&) const
{
    // Offset computation is not implemented yet in this backup commit;
    // return 0 so the function has a well-defined value until the logic lands.
    return 0;
}
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
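    // Three cases follow: a one-element output (an inner product, computed
    // with dot/scal/axpy), a vector B (per-batch gemv, still incomplete in
    // this backup), and the general batched gemm.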
bool is_3inputs = (args.size() == 4);
if (output_shape.elements() == 1)
{
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as, ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(),
to_pointer(args[0]),
1,
to_pointer(args[1]),
1,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2]));
        // scale the accumulated dot product in place: x = alpha * x
        generic_rocblas_scal(as, ctx.get_stream().get_rocblas(),
                             1,
                             &alpha_r,
                             is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                             1);
if (is_3inputs)
{
generic_rocblas_axpy(as, ctx.get_stream().get_rocblas(),
1,
&beta_r,
to_pointer(args[2]),
1,
to_pointer(args[3]),
1);
}
});
return is_3inputs ? args[3] : args[2];
}
    // B is a vector, so the computation is matrix * vector. The
    // vector-inner-product case cannot occur here, since a one-element
    // output was already handled above.
if (args[1].get_shape().lens().size() == 1)
{
        // A may carry batch dimensions, i.e. it can be a batch of matrices
auto a_lens = args[0].get_shape().lens();
std::size_t n_dims = a_lens.size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
bool transa = args[0].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int m = a_lens[dim_0];
rocblas_int k = a_lens[dim_1];
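        // fold every dimension except the last two into a single batch count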
auto batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };
            // Initialize the output: copy C into each batch slot when a third
            // input is given, otherwise zero the whole buffer. Offsets are in
            // bytes, since argument::data() returns a char pointer. The
            // per-batch gemv calls (which will consume alpha_r, beta_r,
            // transa, lda, m, and k) are not yet present in this backup commit.
            auto batch_bytes = output_shape.bytes() / batch_num;
            if(is_3inputs)
            {
                for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
                {
                    hipMemcpy(to_pointer(args[3], batch_no * batch_bytes),
                              to_pointer(args[2]),
                              batch_bytes,
                              hipMemcpyDeviceToDevice);
                }
            }
            else
            {
                hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
            }
        });
        // TODO: launch the per-batch gemv here before returning the output
        return is_3inputs ? args[3] : args[2];
    }
    bool transa = args[0].get_shape().transposed();
    bool transb = args[1].get_shape().transposed();
    std::size_t n_dims = args[0].get_shape().lens().size();
...@@ -112,8 +270,13 @@ argument miopen_gemm::compute(context& ctx,
        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
    // is_3inputs was already computed at the top of compute()
    // two-input case: handled by the shared code below; a dedicated path is
    // left as a placeholder in this backup commit
    output_shape.visit_type([&](auto as) {
        auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
        if(is_3inputs)
            hipMemcpy(to_pointer(args[3]),
                      to_pointer(args[2]),
...@@ -124,9 +287,7 @@ argument miopen_gemm::compute(context& ctx,
    });
    output_shape.visit_type([&](auto as) {
        auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
        generic_rocblas_batched_gemm(as,
                                     ctx.get_stream().get_rocblas(),
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
...
...@@ -18,6 +18,9 @@ struct miopen_gemm
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
    private:
    std::size_t compute_offset(std::vector<std::size_t>& out_lens,
                               std::size_t index,
                               std::vector<std::size_t>& data_lens) const;
};
} // namespace gpu
...
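A note on the scalar path above: when the output has a single element, the GEMM is composed from three BLAS-1 calls, r = dot(a, b); r *= alpha; r += beta * c. Below is a minimal standalone sketch of that decomposition in plain C++ (no rocBLAS; all names are illustrative and not part of this commit):

```cpp
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<float> a{1, 2, 3}, b{4, 5, 6};
    float alpha = 2.0f, beta = 0.5f, c = 10.0f;

    // dot: r = sum_i a[i] * b[i]
    float r = std::inner_product(a.begin(), a.end(), b.begin(), 0.0f);
    r *= alpha;    // scal: scale the dot product by alpha
    r += beta * c; // axpy: accumulate beta * C

    std::cout << r << '\n'; // prints 69 = 2 * 32 + 0.5 * 10
    return 0;
}
```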