"...composable_kernel_onnxruntime.git" did not exist on "e4584d91acc14a22426cbf081c8cc8394c136f6b"
Commit a0f9b785 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Remove gemm copy and simplify rocblas call (#356)

* Remove extra copy in gemm

* combine rocblas gemm call

* clang format

* fix a bug in calling rocblas function

* clang format

* backup of temporary changes

* clang format

* unify the gemm call to avoid multiple gpu implementations

* clang format

* remove unnecessary code

* backup temp changes

* clang format

* fix cppcheck error

* code backup

* clang format

* remove unnecessary synchronization function

* clang format

* fix bugs

* clang format

* more optimization related to gemm

* clang format

* code cleanup

* implementation that achieves better performance

* clang format

* temp changes to try performance

* clang format

* revert to previous commits

* fixed review comments

* clang format

* fix review comments
parent f445d962
...@@ -564,26 +564,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -564,26 +564,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
if(bias != prog.end()) if(bias != prog.end())
{ {
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h); hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h);
}
else
{
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
} }
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
instruction_ref ht1_rh{}; auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
if(bias != prog.end()) if(bias != prog.end())
{ {
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h); ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h);
}
else
{
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
} }
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
} }
......
...@@ -68,8 +68,6 @@ add_library(migraphx_gpu ...@@ -68,8 +68,6 @@ add_library(migraphx_gpu
hip.cpp hip.cpp
target.cpp target.cpp
lowering.cpp lowering.cpp
gemm.cpp
quant_gemm.cpp
pooling.cpp pooling.cpp
convolution.cpp convolution.cpp
quant_convolution.cpp quant_convolution.cpp
...@@ -93,6 +91,7 @@ add_library(migraphx_gpu ...@@ -93,6 +91,7 @@ add_library(migraphx_gpu
clip.cpp clip.cpp
int8_gemm_pack.cpp int8_gemm_pack.cpp
int8_conv_pack.cpp int8_conv_pack.cpp
gemm_impl.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_set_soversion(migraphx_gpu ${PROJECT_VERSION}) rocm_set_soversion(migraphx_gpu ${PROJECT_VERSION})
......
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Dispatch rocblas scal (x = alpha * x) on the element type carried by the
// shape::as<T> tag. Overload resolution picks the precision-specific rocblas
// entry point; element types rocblas cannot scale fall through to the
// throwing overload.
template <class... Args>
rocblas_status generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    return rocblas_sscal(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    return rocblas_dscal(std::forward<Args>(args)...);
}

// Catch-all: no rocblas scal routine exists for this element type.
template <class T, class... Args>
rocblas_status generic_rocblas_scal(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
// Dispatch rocblas axpy (y = alpha * x + y) on the element type carried by
// the shape::as<T> tag; half, float and double map to the corresponding
// rocblas entry points, anything else throws.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    return rocblas_haxpy(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    return rocblas_saxpy(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    return rocblas_daxpy(std::forward<Args>(args)...);
}

// Catch-all: no rocblas axpy routine exists for this element type.
template <class T, class... Args>
rocblas_status generic_rocblas_axpy(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
// Dispatch rocblas dot (inner product) on the element type carried by the
// shape::as<T> tag; only float and double are supported by rocblas here.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    return rocblas_sdot(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    return rocblas_ddot(std::forward<Args>(args)...);
}

// Catch-all: no rocblas dot routine exists for this element type.
template <class T, class... Args>
rocblas_status generic_rocblas_dot(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
// Dispatch rocblas gemv (matrix-vector multiply) on the element type carried
// by the shape::as<T> tag; only float and double are supported by rocblas.
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    return rocblas_sgemv(std::forward<Ts>(xs)...);
}

template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    return rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Catch-all: no rocblas gemv routine exists for this element type.
// Fix: the error tag previously read "GENERIC_ROCBLAS_GEMMV" (double M),
// inconsistent with the function name and with every sibling overload set.
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}
// Dispatch rocblas strided-batched gemm on the element type carried by the
// shape::as<T> tag; float, double and half map to the corresponding rocblas
// strided-batched entry points, anything else throws.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    return rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    return rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    return rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Catch-all: no rocblas strided-batched gemm exists for this element type.
template <class T, class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}
// Dispatch rocblas gemm (single matrix-matrix multiply) on the element type
// carried by the shape::as<T> tag; float, double and half map to the
// corresponding rocblas entry points, anything else throws.
template <class... Args>
rocblas_status generic_rocblas_gemm(shape::as<float>, Args&&... args)
{
    return rocblas_sgemm(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_gemm(shape::as<double>, Args&&... args)
{
    return rocblas_dgemm(std::forward<Args>(args)...);
}

template <class... Args>
rocblas_status generic_rocblas_gemm(shape::as<half>, Args&&... args)
{
    return rocblas_hgemm(std::forward<Args>(args)...);
}

// Catch-all: no rocblas gemm routine exists for this element type.
template <class T, class... Args>
rocblas_status generic_rocblas_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
}
// Metafunction mapping a migraphx element type to the type rocblas expects.
// The default is the identity mapping.
template <class T>
struct compute_rocblas_type
{
using type = T;
};
// Const qualification is preserved through the mapping.
template <class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
// migraphx half is passed to rocblas as rocblas_half.
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
// Convenience alias for the mapped rocblas type.
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;
// Reinterpret a value as its rocblas counterpart. Assumes the migraphx and
// rocblas representations are layout-compatible (bit-identical) -- this holds
// for the identity mappings and, presumably, for half/rocblas_half; confirm
// if new mappings are added.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
// Pointer variant: reinterpret a buffer pointer for passing to rocblas.
template <class T>
rb_type<T>* to_rocblas_type(T* x)
{
return reinterpret_cast<rb_type<T>*>(x);
}
// Non-template overload so plain half values convert without deduction.
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("DOT: batch size {" + to_string_range(strides) + "} is transposed!");
}
}
// Validate the inputs and derive the gemm output shape. The final entry of
// inputs is the pre-allocated output buffer, so it is excluded from the
// broadcast check and from the op's own shape computation.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
// rocblas cannot handle transposed batch dimensions; reject them up front.
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(input_shapes);
}
// Run the gemm on the GPU via rocblas.
// args layout: args[0] = A, args[1] = B, args[2] = C / output buffer, and,
// when a third gemm input is present (is_3inputs), args[3] = the separate
// output buffer that C is first copied into.
argument miopen_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
bool is_3inputs = (args.size() == 4);
// beta only takes effect when a C input is supplied; otherwise the output
// is fully overwritten by alpha * A * B.
float beta = 0.0f;
if(is_3inputs)
{
// Copy C into the output buffer first, because rocblas accumulates in
// place into the output pointer.
// NOTE(review): the hipMemcpyAsync return status is discarded here, so a
// failed device copy would go unnoticed -- consider checking it.
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpyAsync(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice,
ctx.get_stream().get());
});
beta = op.beta;
}
// NOTE(review): a_lens and b_lens are unused below.
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
// dim_0/dim_1 index the row/column dimensions of the (possibly batched)
// matrices; any dimensions before them are batch dimensions.
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
// Leading dimensions come from the strides; a transposed input stores its
// row stride in the other slot.
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
// Total number of matrices = product of all batch dimensions.
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if(num_matrices == 1)
{
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc);
}
else
{
// Batched case: same column-major swap as above, with per-matrix strides
// k*n (B), m*k (A) and m*n (output).
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
k * n,
to_pointer(args[0]),
lda,
m * k,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc,
m * n,
num_matrices);
}
});
// The result lives in whichever buffer rocblas wrote to.
return (is_3inputs ? args[3] : args[2]);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/quant_gemm.hpp> #include <rocblas-types.h>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/generate.hpp>
#include <fstream>
#include <iomanip>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const rocblas_datatype get_type(shape::type_t type)
{ {
std::vector<shape> in_shapes(inputs); switch(type)
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes);
}
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{ {
MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!"); case shape::double_type: return rocblas_datatype_f64_r;
case shape::float_type: return rocblas_datatype_f32_r;
case shape::half_type: return rocblas_datatype_f16_r;
case shape::int8_type: return rocblas_datatype_i8_r;
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::uint16_type:
case shape::int16_type:
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
} }
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
} }
argument rocblas_quant_gemm::compute(context& ctx, template <class T>
const shape& output_shape, void gemm_impl(
const std::vector<argument>& args) const context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
...@@ -48,23 +39,32 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -48,23 +39,32 @@ argument rocblas_quant_gemm::compute(context& ctx,
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
int32_t beta = 0; if(!is_3inputs)
if(is_3inputs) {
beta = 0;
}
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{ {
beta = op.beta; output_type = rocblas_datatype_i32_r;
} }
auto compute_type = output_type;
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(op.alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); }; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
assert(k % 4 == 0); if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
}
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
...@@ -82,19 +82,19 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -82,19 +82,19 @@ argument rocblas_quant_gemm::compute(context& ctx,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, arg_type,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, output_type,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, output_type,
ldc, ldc,
rocblas_datatype_i32_r, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
0, 0,
...@@ -112,24 +112,24 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -112,24 +112,24 @@ argument rocblas_quant_gemm::compute(context& ctx,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, arg_type,
ldb, ldb,
k * n, k * n,
to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, arg_type,
lda, lda,
m * k, m * k,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, output_type,
ldc, ldc,
m * n, m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, output_type,
ldc, ldc,
m * n, m * n,
num_matrices, num_matrices,
rocblas_datatype_i32_r, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
0, 0,
...@@ -137,8 +137,24 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -137,8 +137,24 @@ argument rocblas_quant_gemm::compute(context& ctx,
nullptr); nullptr);
} }
}); });
}
return is_3inputs ? args[3] : args[2]; void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta)
{
gemm_impl(ctx, output_shape, args, alpha, beta);
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta)
{
gemm_impl(ctx, output_shape, args, alpha, beta);
} }
} // namespace gpu } // namespace gpu
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/context.hpp>
#include <miopen/miopen.h> #include <miopen/miopen.h>
#include <vector> #include <vector>
...@@ -112,6 +113,18 @@ void copy_to_gpu(const argument& src, const argument& dst) ...@@ -112,6 +113,18 @@ void copy_to_gpu(const argument& src, const argument& dst)
MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status)); MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status));
} }
// Asynchronous device-to-device copy of src into dst on the context's
// stream. Throws when dst is too small to hold src or when the HIP copy
// itself reports an error.
void gpu_copy(context& ctx, const argument& src, const argument& dst)
{
    const std::size_t nbytes = src.get_shape().bytes();
    if(nbytes > dst.get_shape().bytes())
        MIGRAPHX_THROW("Not enough memory available in destination to do copy");
    const auto err = hipMemcpyAsync(
        dst.data(), src.data(), nbytes, hipMemcpyDeviceToDevice, ctx.get_stream().get());
    if(err != hipSuccess)
        MIGRAPHX_THROW("Gpu copy failed: " + hip_error(err));
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP #define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,9 +14,10 @@ namespace gpu { ...@@ -10,9 +14,10 @@ namespace gpu {
struct context; struct context;
struct miopen_gemm template <class Op>
struct rocblas_gemm
{ {
op::dot op; Op op;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -20,11 +25,49 @@ struct miopen_gemm ...@@ -20,11 +25,49 @@ struct miopen_gemm
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
std::string name() const { return "gpu::gemm"; } std::string name() const
shape compute_shape(const std::vector<shape>& inputs) const; {
if(contains(op.name(), "quant_"))
{
return "gpu::quant_gemm";
}
return "gpu::gemm";
}
shape compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes);
}
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
void batch_not_transposed(const std::vector<std::size_t>& strides) const; {
gemm(ctx, output_shape, args, op.alpha, op.beta);
return args.back();
}
void batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
"} is transposed!");
}
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Run a gemm (output = alpha * A * B [+ beta * C]) on the GPU via rocblas.
// args holds A, B and the output buffer (plus an optional C input).
// Float scaling factors -- used for regular dot operations.
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta);
// int32 scaling factors -- presumably the quantized (quant_dot) path;
// confirm against the gpu lowering which registers both op kinds.
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -24,6 +24,8 @@ void gpu_sync(); ...@@ -24,6 +24,8 @@ void gpu_sync();
void copy_to_gpu(const argument& src, const argument& dst); void copy_to_gpu(const argument& src, const argument& dst);
void gpu_copy(context& ctx, const argument& src, const argument& dst);
struct hip_allocate struct hip_allocate
{ {
shape s; shape s;
...@@ -90,9 +92,9 @@ struct hip_write ...@@ -90,9 +92,9 @@ struct hip_write
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
struct hip_copy struct hip_copy_to_gpu
{ {
std::string name() const { return "hip_copy"; } std::string name() const { return "hip_copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2); check_shapes{inputs}.has(2);
...@@ -106,6 +108,22 @@ struct hip_copy ...@@ -106,6 +108,22 @@ struct hip_copy
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
}; };
// Device-to-device copy operation: copies args[0] into args[1] and returns
// the destination, which it also declares as its output alias.
struct hip_copy
{
std::string name() const { return "hip_copy"; }
// Requires exactly two standard (densely packed) inputs; the result takes
// the destination's shape.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).standard();
return inputs.at(1);
}
argument compute(context& ctx, const shape&, std::vector<argument> args) const
{
gpu_copy(ctx, args[0], args[1]);
return args[1];
}
// The output aliases the second argument (the destination buffer).
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
};
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/quant_dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU lowering of op::quant_dot: a quantized (int8) gemm executed through
// rocblas. The wrapped op carries alpha/beta; the last input passed to
// compute is the pre-allocated output buffer.
struct rocblas_quant_gemm
{
op::quant_dot op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::quant_gemm"; }
// Validates inputs (excluding the output buffer) and derives the result
// shape; defined out of line.
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
// Throws if the batch (leading) stride layout is transposed.
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
// The output aliases the last argument (the output buffer).
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -47,7 +47,6 @@ ...@@ -47,7 +47,6 @@
#include <migraphx/gpu/batchnorm.hpp> #include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/gpu/pooling.hpp> #include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/concat.hpp> #include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp> #include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp> #include <migraphx/gpu/gather.hpp>
...@@ -120,8 +119,6 @@ struct miopen_apply ...@@ -120,8 +119,6 @@ struct miopen_apply
add_generic_op<hip_sign>("sign"); add_generic_op<hip_sign>("sign");
add_generic_op<hip_sigmoid>("sigmoid"); add_generic_op<hip_sigmoid>("sigmoid");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax"); add_extend_op<hip_softmax, op::softmax>("softmax");
...@@ -134,11 +131,12 @@ struct miopen_apply ...@@ -134,11 +131,12 @@ struct miopen_apply
add_extend_op<hip_clip, op::clip>("clip"); add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum"); add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean"); add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
add_lrn_op(); add_lrn_op();
add_convolution_op(); add_convolution_op();
add_quant_convolution_op(); add_quant_convolution_op();
// add_quant_dot_op();
add_pooling_op(); add_pooling_op();
add_batch_norm_inference_op(); add_batch_norm_inference_op();
} }
...@@ -185,6 +183,38 @@ struct miopen_apply ...@@ -185,6 +183,38 @@ struct miopen_apply
}); });
} }
// Register the lowering of a dot-like op (op::dot or op::quant_dot) to
// rocblas_gemm<Op>. Decides how the C input / output buffer is wired up:
//  - two inputs (no C): allocate an output and force beta = 0;
//  - three inputs where C is shared with other consumers, or this is the
//    last instruction: copy C into a fresh allocation so rocblas can
//    accumulate in place without clobbering the shared value;
//  - otherwise C is reused directly as the output buffer.
template <class Op>
void add_gemm_op(std::string name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator());
auto beta = op.beta;
std::vector<instruction_ref> refs = ins->inputs();
if((refs.size() == 2) or (refs.size() == 3 and refs.back()->outputs().size() > 1) or
(ins == last))
{
auto output = insert_allocation(ins, ins->get_shape());
if(refs.size() == 2)
{
// No C input: the gemm fully overwrites the output.
beta = 0;
refs.push_back(output);
}
else
{
// C is shared (or this is the final instruction): work on a copy so the
// in-place accumulation does not disturb other readers of C.
auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output);
refs.back() = copy_out;
refs.push_back(copy_out);
}
}
else
{
// Safe to accumulate directly into the C buffer: alias it as the output.
refs.push_back(refs.back());
}
return prog->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs);
});
}
void add_quant_convolution_op() void add_quant_convolution_op()
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
......
...@@ -45,7 +45,7 @@ void write_literals::apply(program& p) const ...@@ -45,7 +45,7 @@ void write_literals::apply(program& p) const
literal l = ins->get_literal(); literal l = ins->get_literal();
auto pre = p.add_literal(l); auto pre = p.add_literal(l);
auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()}); auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
p.replace_instruction(ins, hip_copy{}, pre, alloc); p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment