Commit 20b1d690 authored by Paul

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
#include <migraphx/gpu/device/sqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::sqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -8,7 +8,7 @@ namespace device {
 void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
-    nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; });
+    nary(stream, result, arg1, arg2)([](auto x, auto y) { return x - y; });
 }
 
 } // namespace device
...
#include <migraphx/gpu/device/tanh.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void tanh(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::tanh(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -2,14 +2,19 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
-#include <migraphx/gpu/device/add_relu.hpp>
+#include <migraphx/gpu/oper.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
+#include <migraphx/gpu/device/add_unary.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/array.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)
+
 struct fusion
 {
     using op_t = miopenFusionOpDescriptor_t;
@@ -122,15 +127,10 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
            s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
 }
 
-// TODO: Move to another header
-template <class T, class... Ts>
-std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
-{
-    return {std::move(x), std::move(static_cast<T>(xs))...};
-}
-
 MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
 {
+    if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
+        return false;
     if(ins->name() != "gpu::convolution")
         return false;
     if(ins->get_shape().type() != shape::float_type)
@@ -140,11 +140,13 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
     if(conv.op.group > 1)
         return false;
-    if(conv.op.padding_mode != op::padding_mode_t::default_)
-        return false;
     if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
     auto op = conv.op;
+    // Don't fuse winograd for non-3x3s since there is no fused winograd for those configs
+    if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and
+       wei.lens()[3] != 3 and op.stride == make_array<size_t>(1, 1))
+        return false;
     return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
            contains({{0, 0}, {1, 1}}, op.stride) and op.dilation == make_array<size_t>(1, 1);
 }
@@ -168,9 +170,33 @@ struct hip_triadd
     }
 };
 
-struct hip_triadd_relu
-{
-    std::string name() const { return "hip::triadd_relu"; }
+struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
+{
+};
+
+struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
+{
+};
+
+struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
+{
+};
+
+struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
+{
+};
+
+struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
+{
+};
+
+struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
+{
+};
+
+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(4);
@@ -178,7 +204,7 @@ struct hip_triadd_relu
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
-        device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
         return args.at(3);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
@@ -187,18 +213,19 @@ struct hip_triadd_relu
     }
 };
 
-struct hip_add_relu
+struct hip_mul_add_relu
 {
-    std::string name() const { return "hip::add_relu"; }
+    std::string name() const { return "hip::mul_add_relu"; }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
-        check_shapes{inputs, *this}.has(3);
+        check_shapes{inputs, *this}.has(4);
         return inputs.front();
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
-        device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
-        return args.at(2);
+        device::mul_add_relu(
+            ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
@@ -206,12 +233,40 @@ struct hip_add_relu
     }
 };
 
-struct find_add_relu
+void move_broadcasted_back(std::vector<instruction_ref>& args)
+{
+    // Ensure the last argument is the broadcasted one
+    auto last = std::prev(args.end());
+    auto it =
+        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
+}
+
+void move_standard_front(std::vector<instruction_ref>& args)
+{
+    // Ensure the first argument is the standard one
+    auto last = std::prev(args.end());
+    auto it =
+        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
+        std::swap(*it, args.front());
+}
+
+struct find_add_unary
 {
+    std::string op_name;
+    operation binary_add_op;
+    operation ternary_add_op;
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::arg(0)(
-            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+        return match::name(op_name)(match::arg(0)(
+            match::used_once(),
+            match::any_of(match::name("gpu::add"),
+                          match::name("hip::triadd"),
+                          match::any_of(match::name("@literal"),
                                        match::any_of[match::inputs()](match::standard_shape())))
+                .bind("add")));
     }
 
     void apply(program& p, match::matcher_result r) const
@@ -219,12 +274,15 @@ struct find_add_relu
         auto add_ins = r.instructions["add"];
         auto ins = r.result;
         auto args = add_ins->inputs();
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
-            p.replace_instruction(ins, hip_add_relu{}, args);
+            p.replace_instruction(ins, binary_add_op, args);
        else if(add_ins->name() == "hip::triadd")
-            p.replace_instruction(ins, hip_triadd_relu{}, args);
+            p.replace_instruction(ins, ternary_add_op, args);
     }
 };
@@ -232,29 +290,78 @@ struct find_triadd
 {
     auto matcher() const
     {
-        return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
-                                                               match::any().bind("input")));
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::add")(match::used_once()).bind("add"),
+            match::any(match::any_of(match::name("@literal"),
                                     match::any_of[match::inputs()](match::standard_shape())))
+                .bind("input")));
     }
 
     void apply(program& p, match::matcher_result r) const
     {
         auto add_ins = r.instructions["add"];
         auto input_ins = r.instructions["input"];
         auto ins = r.result;
         auto args = add_ins->inputs();
+        assert(add_ins != input_ins);
         auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
         if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
             return;
         args.insert(args.begin(), input_ins);
-        // Ensure the last arguments is the broadcasted one
-        auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
-        if(it != args.end())
-            std::swap(*it, *std::prev(args.end(), 2));
+        move_standard_front(args);
+        move_broadcasted_back(args);
         args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_triadd{}, args);
     }
 };
 
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins = r.instructions["b"];
+        auto ins = r.result;
+        auto args = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+
+        move_standard_front(args);
+        move_broadcasted_back(args);
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
+struct find_mul_add_relu
+{
+    auto matcher() const
+    {
+        return match::name("gpu::relu")(
+            match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_add_ins = r.instructions["mul_add"];
+        auto ins = r.result;
+        auto args = mul_add_ins->inputs();
+
+        // Use the allocation from the relu operator
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;
@@ -410,7 +517,11 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
-                        find_add_relu{}
+                        find_mul_add{},
+                        find_mul_add_relu{},
+                        find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
+                        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
+                        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}
     );
     // clang-format on
 }
...
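
The one-line operator structs above (hip_triadd_relu and friends) get their behavior from binary_device and ternary_device, CRTP helpers pulled in through the new <migraphx/gpu/oper.hpp> include. That header is not part of this diff, so the following is only a sketch of what the ternary helper plausibly provides, inferred from the boilerplate the deleted hand-written structs used to carry; the exact member set and the generated name() are assumptions:

// Sketch only: oper.hpp is not shown in this commit. The function-pointer
// parameter matches the ternary overloads declared in device/add_unary.hpp.
template <class Derived,
          void (*F)(
              hipStream_t, const argument&, const argument&, const argument&, const argument&)>
struct ternary_device
{
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Three inputs plus the preallocated output buffer.
        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(4);
        return inputs.front();
    }
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        F(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
        return args.at(3); // the result is written into the trailing allocation
    }
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1; // output aliases the last argument
    }
};

binary_device would follow the same pattern with has(3) and a two-input device call; a name() member (returning "hip::triadd_relu" and so on) is presumably also generated or supplied per operator.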
@@ -12,11 +12,9 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
     return op.compute_shape(inputs);
 }
 
-argument hip_gather::compute(context& ctx,
-                             const shape& output_shape,
-                             const std::vector<argument>& args) const
+argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
+    return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
 }
 
 } // namespace gpu
...
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <class... Ts>
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
{
return rocblas_sscal(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
{
return rocblas_dscal(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_scal(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
{
return rocblas_haxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
{
return rocblas_saxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
{
return rocblas_daxpy(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
return rocblas_sdot(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{
return rocblas_ddot(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemv(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemv(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
{
return rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
return rocblas_hgemm(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
}
template <class T>
struct compute_rocblas_type
{
using type = T;
};
template <class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
template <class T>
rb_type<T>* to_rocblas_type(T* x)
{
return reinterpret_cast<rb_type<T>*>(x);
}
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
}
argument miopen_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
bool is_3inputs = (args.size() == 4);
float beta = 0.0f;
if(is_3inputs)
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpyAsync(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice,
ctx.get_stream().get());
});
beta = op.beta;
}
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if(num_matrices == 1)
{
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc);
}
else
{
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
k * n,
to_pointer(args[0]),
lda,
m * k,
&beta_r,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc,
m * n,
num_matrices);
}
});
return (is_3inputs ? args[3] : args[2]);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <rocblas-types.h>
#include <migraphx/gpu/gemm_impl.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
rocblas_datatype get_type(shape::type_t type)
{
switch(type)
{
case shape::double_type: return rocblas_datatype_f64_r;
case shape::float_type: return rocblas_datatype_f32_r;
case shape::half_type: return rocblas_datatype_f16_r;
case shape::int8_type: return rocblas_datatype_i8_r;
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::uint16_type:
case shape::int16_type:
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}
template <class T>
void gemm_impl(
context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta)
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
bool is_3inputs = (args.size() == 4);
if(!is_3inputs)
{
beta = 0;
}
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
}
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
{
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
}
else
{
rocblas_gemm_strided_batched_ex(
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
k * n,
to_pointer(args.at(0)),
arg_type,
lda,
m * k,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
m * n,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
}
});
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta)
{
gemm_impl(ctx, output_shape, args, alpha, beta);
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta)
{
gemm_impl(ctx, output_shape, args, alpha, beta);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
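
Both the typed gemm.cpp path above and this gemm_ex path rely on the operand swap called out in the comment: rocBLAS assumes column-major storage while MIGraphX buffers are row-major, and a row-major m x k buffer is bit-identical to a column-major k x m matrix, i.e. its transpose. So instead of repacking data, the code computes

    C = A * B   <=>   C^T = B^T * A^T

by handing args[1] to rocBLAS where it expects A and args[0] where it expects B, with m and n swapped (hence the n, m, k argument order at both call sites). The batched strides k * n, m * k, and m * n are simply the element counts of one B, A, and C matrix, so matrix i of the batch starts immediately after matrix i - 1.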
@@ -2,6 +2,7 @@
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/manage_ptr.hpp>
+#include <migraphx/gpu/context.hpp>
 #include <miopen/miopen.h>
 #include <vector>
@@ -112,6 +113,18 @@ void copy_to_gpu(const argument& src, const argument& dst)
         MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status));
 }
 
+void gpu_copy(context& ctx, const argument& src, const argument& dst)
+{
+    std::size_t src_size = src.get_shape().bytes();
+    std::size_t dst_size = dst.get_shape().bytes();
+    if(src_size > dst_size)
+        MIGRAPHX_THROW("Not enough memory available in destination to do copy");
+    auto status = hipMemcpyAsync(
+        dst.data(), src.data(), src_size, hipMemcpyDeviceToDevice, ctx.get_stream().get());
+    if(status != hipSuccess)
+        MIGRAPHX_THROW("Gpu copy failed: " + hip_error(status));
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmax
{
op::argmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmin
{
op::argmin op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -3,8 +3,6 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/op/convert.hpp>
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/convert.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -12,7 +10,7 @@ namespace gpu {
 struct context;
 
-struct hip_convert : unary_device<hip_convert, device::convert>
+struct hip_convert
 {
     op::convert op;
@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert>
         return migraphx::reflect(self.op, f);
     }
 
-    hip_convert(op::convert oper) : op(oper) {}
-    shape compute_shape(std::vector<shape> inputs) const
+    std::string name() const { return "gpu::convert"; }
+    shape compute_shape(std::vector<shape> inputs) const;
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
-        inputs.pop_back();
-        check_shapes{inputs}.packed();
-        return op.compute_shape(inputs);
+        return shapes.size() - 1;
     }
 };
...
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_UNARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_UNARY_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
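
This header only declares the fused add-plus-activation kernels; their definitions are not part of this commit. Following the nary pattern used by sqrt.cpp, sub.cpp, and tanh.cpp at the top of this diff, a plausible implementation of the binary add_relu would look like the sketch below (hypothetical, not the actual source file):

#include <migraphx/gpu/device/add_unary.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Sketch: fuse the elementwise add and the relu clamp into a single kernel
// launch, mirroring how sqrt/tanh wrap a lambda with nary() in this commit.
void add_relu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)([](auto x, auto y) {
        auto sum = x + y;
        return sum > 0 ? sum : decltype(sum){0}; // relu(x + y)
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

The ternary overloads and the sigmoid/tanh variants would differ only in the lambda body, which is exactly the duplication these fused entry points push down into one place.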
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T>
struct val_index
{
T val;
int64_t index;
};
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
return {v, -1};
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
return {v, i};
}
struct argmax_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
return x;
else if(x.val > y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
auto batch_lens = lens;
size_t batch_item_num = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{arg_shape.type(), batch_lens};
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch.
const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
auto init = make_val_index<type>(op.init());
auto op_output =
block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return make_val_index(input[arg_s.index(data_idx)], j);
});
if(idx.local == 0)
{
output[batch_s.index(batch_idx)] = op_output.index;
}
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
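
The arg_op template above is the shared machinery: one thread block reduces the elements along the requested axis for each output position, carrying a (value, index) pair through block_reduce so that ties resolve to the lowest index. The argmax and argmin entry points declared in the next two headers are thin wrappers over it; their definitions are not in this diff, but given the declared signatures they would plausibly reduce to:

// Sketch: the concrete .cpp bodies are not shown in this commit.
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    arg_op(argmax_op{}, stream, result, arg, axis); // index of the largest element
}

void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    arg_op(argmin_op{}, stream, result, arg, axis); // index of the smallest element
}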
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMIN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ERF_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ERF_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void erf(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis);
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis);
 
 } // namespace device
 } // namespace gpu
...
-#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
-#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
+#define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
@@ -11,16 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-void add_relu(hipStream_t stream,
-              const argument& result,
-              const argument& arg1,
-              const argument& arg2);
+void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg);
 
-void add_relu(hipStream_t stream,
-              const argument& result,
-              const argument& arg1,
-              const argument& arg2,
-              const argument& arg3);
+void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg);
 
 } // namespace device
 } // namespace gpu
...
@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument logsoftmax(hipStream_t stream,
-                    const migraphx::shape& output_shape,
-                    std::vector<migraphx::argument> args,
-                    int axis);
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
 
 } // namespace device
 } // namespace gpu
...