Commit 20b1d690 authored by Paul

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
@@ -44,13 +44,9 @@ struct is_fast_gemm_type<float> : std::true_type
 {
 };
 
-template <class T>
-void migemm_impl(tensor_view<T> cmat,
-                 tensor_view<T> amat,
-                 tensor_view<T> bmat,
-                 float alpha,
-                 float beta,
-                 std::true_type)
+template <class T, class F>
+void migemm_impl(
+    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
 {
     visit_mat(amat, [&](const auto& a) {
         visit_mat(bmat, [&](const auto& b) {
@@ -66,13 +62,9 @@ void migemm_impl(tensor_view<T> cmat,
         });
     });
 }
 
-template <class T>
-void migemm_impl(tensor_view<T> cmat,
-                 tensor_view<T> amat,
-                 tensor_view<T> bmat,
-                 float alpha,
-                 float beta,
-                 std::false_type)
+template <class T, class F>
+void migemm_impl(
+    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
 {
     std::size_t n_dims = cmat.get_shape().lens().size();
     std::size_t dim_0  = n_dims - 2;
@@ -95,9 +87,8 @@ void migemm_impl(tensor_view<T> cmat,
     });
 }
 
-template <class T>
-void migemm_impl(
-    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
+template <class T, class F>
+void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
 {
     auto lens = amat.get_shape().lens();
     bool batch_mul =
@@ -113,13 +104,29 @@ void migemm_impl(
     }
 }
 
-void migemm(
-    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
+template <class F>
+void migemm_tpl(
+    const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
 {
     visit_all(c_arg, a_arg, b_arg)(
         [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
 }
 
+void migemm(
+    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
+{
+    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
+}
+
+void migemm(const argument& c_arg,
+            const argument& a_arg,
+            const argument& b_arg,
+            int32_t alpha,
+            int32_t beta)
+{
+    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
+}
+
 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
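The std::true_type / std::false_type third parameter on migemm_impl is ordinary tag dispatch: is_fast_gemm_type selects the fast gemm path at compile time, and both public migemm overloads funnel through the same migemm_tpl. A minimal standalone sketch of the pattern (names here are hypothetical stand-ins, not the migraphx API):

#include <cstdint>
#include <iostream>
#include <type_traits>

// Hypothetical stand-in for is_fast_gemm_type: only float takes the fast path.
template <class T>
struct is_fast : std::is_same<T, float>
{
};

template <class T> void gemm_impl(T, std::true_type) { std::cout << "fast path\n"; }
template <class T> void gemm_impl(T, std::false_type) { std::cout << "generic path\n"; }

// One entry point; the tag picks the overload at compile time.
template <class T>
void gemm(T x) { gemm_impl(x, is_fast<T>{}); }

int main()
{
    gemm(1.0f);       // fast path
    gemm(int32_t{1}); // generic path, as for the quantized int32 gemm
}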
@@ -10,6 +10,11 @@ namespace cpu {
 void migemm(
     const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
 
+void migemm(const argument& c_arg,
+            const argument& a_arg,
+            const argument& b_arg,
+            int32_t alpha,
+            int32_t beta);
+
 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -15,6 +15,10 @@ struct target
     std::string name() const;
     std::vector<pass> get_passes(migraphx::context& ctx) const;
     migraphx::context get_context() const { return context{}; }
+    argument copy_to(const argument& arg) const { return arg; }
+    argument copy_from(const argument& arg) const { return arg; }
+    argument allocate(const shape& s) const;
 };
 
 } // namespace cpu
...
@@ -2,7 +2,21 @@
 #include <migraphx/cpu/lowering.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/dfor.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/batch_norm.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/elu.hpp>
+#include <migraphx/op/im2col.hpp>
+#include <migraphx/op/leaky_relu.hpp>
+#include <migraphx/op/logsoftmax.hpp>
+#include <migraphx/op/lrn.hpp>
+#include <migraphx/op/pad.hpp>
+#include <migraphx/op/pooling.hpp>
+#include <migraphx/op/softmax.hpp>
+#include <migraphx/op/argmax.hpp>
+#include <migraphx/op/argmin.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/par_dfor.hpp>
@@ -204,6 +218,61 @@ struct cpu_convolution
     }
 };
 
+struct cpu_quant_convolution
+{
+    op::quant_convolution op;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return migraphx::reflect(self.op, f);
+    }
+
+    std::string name() const { return "cpu::quant_convolution"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+
+    argument compute(context&, shape output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        auto output = result.get<int32_t>();
+        visit_all(args[0], args[1])([&](auto input, auto weights) {
+            auto in   = input.get_shape().lens();
+            auto in_h = in[2];
+            auto in_w = in[3];
+
+            auto wei   = weights.get_shape().lens();
+            auto wei_n = wei[0];
+            auto wei_c = wei[1];
+            auto wei_h = wei[2];
+            auto wei_w = wei[3];
+
+            par_dfor(output_shape.lens()[0],
+                     output_shape.lens()[1],
+                     output_shape.lens()[2],
+                     output_shape.lens()[3])(
+                [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
+                    const auto start_x  = i * op.stride[0] - op.padding[0];
+                    const auto start_y  = j * op.stride[1] - op.padding[1];
+                    const auto group_id = w / (wei_n / op.group);
+
+                    int32_t acc = 0;
+                    dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
+                        const auto in_x  = start_x + x;
+                        const auto in_y  = start_y + y;
+                        const auto in_ch = group_id * wei_c + k;
+                        if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
+                        {
+                            acc += static_cast<int32_t>(input(o, in_ch, in_x, in_y)) *
+                                   weights(w, k, x, y);
+                        }
+                    });
+                    output(o, w, i, j) = acc;
+                });
+        });
+
+        return result;
+    }
+};
+
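cpu_quant_convolution widens each int8 operand to int32 before the multiply-accumulate, since even a single int8*int8 product overflows int8. A standalone sketch of the accumulation for one output element (hypothetical data, not the migraphx API):

#include <cstdint>
#include <iostream>
#include <vector>

int32_t quant_dot(const std::vector<int8_t>& in, const std::vector<int8_t>& wei)
{
    int32_t acc = 0; // accumulate in int32, as the kernel above does
    for(std::size_t k = 0; k < in.size(); ++k)
        acc += static_cast<int32_t>(in[k]) * static_cast<int32_t>(wei[k]);
    return acc;
}

int main()
{
    // A 3x3 window of maximal int8 values.
    std::vector<int8_t> in(9, 127), wei(9, 127);
    std::cout << quant_dot(in, wei) << "\n"; // 145161, far outside int8 range
}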
 struct cpu_im2col
 {
     op::im2col op;
@@ -233,17 +302,17 @@ struct cpu_im2col
         const std::size_t& stride_h = op.stride[0];
         const std::size_t& stride_w = op.stride[1];
 
-        auto kdiv2_h = kernel_h / 2;
-        auto kdiv2_w = kernel_w / 2;
+        long kdiv2_h = long(kernel_h) / 2;
+        long kdiv2_w = long(kernel_w) / 2;
         // calculate output sizes
         const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
         const std::size_t col_width  = (width - kernel_w + 2 * pad_w) / stride_w + 1;
 
         // account for padding for the starting position of the input pixels
-        std::size_t iinput = kdiv2_h - pad_h;
+        long iinput = kdiv2_h - long(pad_h);
         // loop over output pixels (ioutput, joutput)
         for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h)
         {
-            std::size_t jinput = kdiv2_w - pad_w;
+            long jinput = kdiv2_w - long(pad_w);
             for(std::size_t joutput = 0; joutput < col_width; joutput++, jinput += stride_w)
             {
                 // compute linear index for output
@@ -252,8 +321,8 @@ struct cpu_im2col
                 dfor(channels,
                      kernel_h,
                      kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) {
-                    auto idx = iinput + koffset - kdiv2_h;
-                    auto jdx = jinput + loffset - kdiv2_w;
+                    auto idx = iinput + long(koffset) - kdiv2_h;
+                    auto jdx = jinput + long(loffset) - kdiv2_w;
                     col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width))
                                       ? input(0, c, idx, jdx)
                                       : 0;
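The std::size_t to long change above is a correctness fix: kdiv2_h - pad_h can be negative for padded windows, and unsigned subtraction wraps instead. A two-line illustration:

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t kdiv2_h = 1, pad_h = 2;
    std::cout << kdiv2_h - pad_h << "\n";             // wraps to a huge unsigned value
    std::cout << long(kdiv2_h) - long(pad_h) << "\n"; // -1, the intended start index
}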
@@ -421,7 +490,7 @@ struct cpu_gemm
     {
         argument result{output_shape};
         // 3 inputs, it is alpha * A * B + beta * C, then
-        // A and B are matrices, and C is broadcastable to A * B
+        // A and B are matrices, and C is of the same shape as A * B
         if(args.size() == 3)
         {
             // no need to consider the value of args[2]
@@ -448,13 +517,79 @@
     }
 };
 
+struct cpu_quant_gemm
+{
+    op::quant_dot op;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return migraphx::reflect(self.op, f);
+    }
+
+    std::string name() const { return "cpu::quant_dot"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(inputs.size() == 3)
+        {
+            auto c_shape = inputs.at(2);
+            check_shapes{{c_shape}}.not_broadcasted();
+        }
+        return op.compute_shape(inputs);
+    }
+
+    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        // 3 inputs, it is alpha * A * B + beta * C, then
+        // A and B are matrices, and C is of the same shape as A * B
+
+        // first, convert the args[0] and args[1] from int8_t to int32_t
+        argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
+        argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
+        arg_0.visit([&](auto output) {
+            args.at(0).visit(
+                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
+        });
+        arg_1.visit([&](auto output) {
+            args.at(1).visit(
+                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
+        });
+
+        if(args.size() == 3)
+        {
+            // no need to consider the value of args[2]
+            if(op.beta == 0)
+            {
+                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
+            }
+            else
+            {
+                visit_all(result, args[2])([&](auto output, auto input) {
+                    std::copy(input.begin(), input.end(), output.begin());
+                });
+            }
+
+            migemm(result, arg_0, arg_1, op.alpha, op.beta);
+
+            return result;
+        }
+
+        // 2 input arguments
+        migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
+
+        return result;
+    }
+};
+
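cpu_quant_gemm computes alpha * A * B + beta * C entirely in int32; with beta == 0 the output is zero-filled first so stale C values cannot leak in. A scalar sketch of that contract (hypothetical 1x1 matrices, not the migraphx API):

#include <cstdint>
#include <iostream>

int32_t quant_gemm_1x1(int32_t a, int32_t b, int32_t c, int32_t alpha, int32_t beta)
{
    int32_t out = (beta == 0) ? 0 : beta * c; // beta == 0 ignores C entirely
    out += alpha * a * b;
    return out;
}

int main()
{
    std::cout << quant_gemm_1x1(3, 4, 100, 1, 0) << "\n"; // 12: C ignored
    std::cout << quant_gemm_1x1(3, 4, 100, 1, 1) << "\n"; // 112
}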
 struct leaky_relu_op
 {
     op::leaky_relu op;
     std::string name() const { return "cpu::leaky_relu"; }
     auto fcn() const
     {
-        auto& a = op.alpha;
+        auto a = op.alpha;
         return [a](auto x) { return x > 0 ? x : x * a; };
     }
 };
@@ -465,7 +600,7 @@ struct elu_op
     std::string name() const { return "cpu::elu"; }
     auto fcn() const
     {
-        auto& a = op.alpha;
+        auto a = op.alpha;
         return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
     }
 };
@@ -517,40 +652,60 @@ struct cpu_unary
     }
 };
 
-struct softmax2d
+struct cpu_softmax
 {
-    std::string name() const { return "cpu::softmax2d"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    op::softmax op;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return migraphx::reflect(self.op, f);
+    }
+
+    std::string name() const { return "cpu::softmax"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
+
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
-            auto nb          = input.get_shape().lens()[0];
-            auto nc          = input.get_shape().lens()[1];
-            auto nh          = input.get_shape().lens()[2];
-            auto nw          = input.get_shape().lens()[3];
-            dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
-                value_type cmax = std::numeric_limits<value_type>::lowest();
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    cmax = std::max(cmax, input(b, c, i, j));
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
-                }
-                value_type sum = value_type(0);
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    sum += output(b, c, i, j);
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = output(b, c, i, j) / sum;
-                }
+            std::vector<value_type> batch_max(batch_shape.elements(),
+                                              std::numeric_limits<value_type>::lowest());
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = std::exp(input[index] - batch_max[i]);
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += output(idx.begin(), idx.end());
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) /= batch_sum[i];
+                }
             });
         });
 
         return result;
     }
 };
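cpu_softmax reduces along an arbitrary axis by zeroing that axis in batch_lens and walking it per batch element, applying the usual max-shift for numerical stability. A standalone sketch of the same max, exp, normalize sequence on one slice:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<float> softmax_slice(std::vector<float> x)
{
    float cmax = *std::max_element(x.begin(), x.end()); // subtract max for stability
    float sum  = 0;
    for(auto& v : x)
    {
        v = std::exp(v - cmax);
        sum += v;
    }
    for(auto& v : x)
        v /= sum;
    return x;
}

int main()
{
    for(auto v : softmax_slice({1.f, 2.f, 3.f}))
        std::cout << v << " "; // ~0.09 0.245 0.665
    std::cout << "\n";
}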
@@ -567,63 +722,50 @@ struct cpu_logsoftmax
     std::string name() const { return "cpu::logsoftmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
 
-    template <typename T>
-    std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
-    {
-        if(axis == 0)
-        {
-            return 0;
-        }
-        else
-        {
-            std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
-            return batch_shape.index(batch_idx.begin(), batch_idx.end());
-        }
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        auto lens = output_shape.lens();
-        std::vector<std::size_t> batch_lens{};
-        if(op.axis == 0)
-        {
-            batch_lens.push_back(1);
-        }
-        else
-        {
-            batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
-        }
-        shape batch_shape{migraphx::shape::uint32_type, batch_lens};
+        auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
 
+        // use a parallel implementation to achieve better performance:
+        // one thread per batch
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
-            });
-
-            for(std::size_t i = 0; i < batch_sum.size(); ++i)
-            {
-                batch_sum[i] = std::log(batch_sum[i]);
-            }
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) -= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = input[index] - batch_max[i];
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
+                }
+
+                batch_sum[i] = std::log(batch_sum[i]);
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) -= batch_sum[i];
+                }
+            });
         });
@@ -652,15 +794,17 @@ struct cpu_apply
     {
         apply_map["batch_norm_inference"] =
             extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
-        apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
-        apply_map["dot"]         = extend_op<cpu_gemm, op::dot>();
-        apply_map["elu"]         = extend_op<cpu_unary<elu_op>, op::elu>();
-        apply_map["im2col"]      = extend_op<cpu_im2col, op::im2col>();
-        apply_map["leaky_relu"]  = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
-        apply_map["logsoftmax"]  = extend_op<cpu_logsoftmax, op::logsoftmax>();
-        apply_map["lrn"]         = extend_op<cpu_lrn, op::lrn>();
-        apply_map["pad"]         = extend_op<cpu_pad, op::pad>();
-        apply_map["softmax"]     = simple_op<softmax2d>();
+        apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
+        apply_map["dot"]         = extend_op<cpu_gemm, op::dot>();
+        apply_map["quant_dot"]   = extend_op<cpu_quant_gemm, op::quant_dot>();
+        apply_map["quant_convolution"] = extend_op<cpu_quant_convolution, op::quant_convolution>();
+        apply_map["elu"]         = extend_op<cpu_unary<elu_op>, op::elu>();
+        apply_map["im2col"]      = extend_op<cpu_im2col, op::im2col>();
+        apply_map["leaky_relu"]  = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
+        apply_map["logsoftmax"]  = extend_op<cpu_logsoftmax, op::logsoftmax>();
+        apply_map["lrn"]         = extend_op<cpu_lrn, op::lrn>();
+        apply_map["pad"]         = extend_op<cpu_pad, op::pad>();
+        apply_map["softmax"]     = extend_op<cpu_softmax, op::softmax>();
     }
 
     void apply()
...
@@ -5,6 +5,7 @@
 #include <migraphx/auto_contiguous.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/generate.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -22,6 +23,8 @@ std::vector<pass> target::get_passes(migraphx::context&) const
             dead_code_elimination{}};
 }
 
+argument target::allocate(const shape& s) const { return fill_argument(s, 0); }
+
 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -12,65 +12,89 @@ endif()
 add_library(migraphx_device
     device/add.cpp
+    device/argmax.cpp
+    device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
     device/exp.cpp
+    device/erf.cpp
     device/log.cpp
     device/sin.cpp
     device/cos.cpp
     device/tan.cpp
     device/sinh.cpp
     device/cosh.cpp
+    device/tanh.cpp
     device/asin.cpp
     device/acos.cpp
     device/atan.cpp
-    device/add_relu.cpp
+    device/relu.cpp
+    device/add_unary.cpp
     device/contiguous.cpp
     device/logsoftmax.cpp
+    device/softmax.cpp
+    device/sigmoid.cpp
     device/convert.cpp
     device/mul.cpp
     device/concat.cpp
     device/pad.cpp
     device/gather.cpp
     device/sub.cpp
+    device/int8_gemm_pack.cpp
+    device/div.cpp
     device/clip.cpp
+    device/reduce_sum.cpp
+    device/rsqrt.cpp
+    device/round.cpp
+    device/sqrt.cpp
+    device/reduce_mean.cpp
+    device/pow.cpp
+    device/sqdiff.cpp
+    device/sign.cpp
 )
 set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
+rocm_set_soversion(migraphx_device ${PROJECT_VERSION})
 rocm_clang_tidy_check(migraphx_device)
 target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx906)
 target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
 target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
 
 add_library(migraphx_gpu
+    argmax.cpp
+    argmin.cpp
     eliminate_workspace.cpp
     fuse_ops.cpp
     hip.cpp
     target.cpp
     lowering.cpp
+    gemm.cpp
     pooling.cpp
     convolution.cpp
+    quant_convolution.cpp
     softmax.cpp
     logsoftmax.cpp
     contiguous.cpp
     concat.cpp
+    relu.cpp
     leaky_relu.cpp
+    tanh.cpp
     batchnorm.cpp
     write_literals.cpp
     rocblas.cpp
+    sigmoid.cpp
     abs.cpp
     elu.cpp
     pad.cpp
     gather.cpp
+    convert.cpp
     lrn.cpp
     schedule_model.cpp
     adjust_allocation.cpp
+    pack_int8_args.cpp
     clip.cpp
+    int8_gemm_pack.cpp
+    int8_conv_pack.cpp
+    gemm_impl.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
+rocm_set_soversion(migraphx_gpu ${PROJECT_VERSION})
 rocm_clang_tidy_check(migraphx_gpu)
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
 target_link_libraries(migraphx_gpu PRIVATE migraphx_device)
...
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::convert(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
-#include <migraphx/gpu/device/add_relu.hpp>
+#include <migraphx/gpu/device/add_unary.hpp>
 #include <migraphx/gpu/device/nary.hpp>
 
 namespace migraphx {
@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
+}
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
@@ -15,6 +25,23 @@ void add_relu(hipStream_t stream,
         [](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
 }
 
+void add_sigmoid(hipStream_t stream,
+                 const argument& result,
+                 const argument& arg1,
+                 const argument& arg2)
+{
+    nary(stream, result, arg1, arg2)(
+        [](auto x, auto y) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); });
+}
+
+void add_tanh(hipStream_t stream,
+              const argument& result,
+              const argument& arg1,
+              const argument& arg2)
+{
+    nary(stream, result, arg1, arg2)([](auto x, auto y) { return ::tanh(to_hip_type(x + y)); });
+}
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
@@ -25,6 +52,26 @@ void add_relu(hipStream_t stream,
         [](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); });
 }
 
+void add_sigmoid(hipStream_t stream,
+                 const argument& result,
+                 const argument& arg1,
+                 const argument& arg2,
+                 const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto y, auto z) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z)))); });
+}
+
+void add_tanh(hipStream_t stream,
+              const argument& result,
+              const argument& arg1,
+              const argument& arg2,
+              const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto y, auto z) { return ::tanh(to_hip_type(x + y + z)); });
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
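Each fused kernel above is one elementwise pass over the operands; their scalar equivalents are simple to state (a sketch, independent of the HIP launch machinery):

#include <algorithm>
#include <cmath>
#include <iostream>

float add_relu_s(float x, float y) { return std::max(0.f, x + y); }
float add_sigmoid_s(float x, float y) { return 1.f / (1.f + std::exp(-(x + y))); }
float add_tanh_s(float x, float y) { return std::tanh(x + y); }

int main()
{
    std::cout << add_relu_s(-2.f, 1.f) << " "   // 0
              << add_sigmoid_s(0.f, 0.f) << " " // 0.5
              << add_tanh_s(0.f, 0.f) << "\n";  // 0
}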
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -10,22 +10,22 @@ namespace gpu {
 namespace device {
 
 argument concat(hipStream_t stream,
-                const migraphx::shape& output_shape,
+                const migraphx::shape&,
                 std::vector<migraphx::argument> args,
                 std::vector<std::size_t> offsets)
 {
-    for(std::size_t l = 0; l < args.size() - 1; l++)
+    auto ninputs = args.size() - 1;
+    for(std::size_t j = 0; j < ninputs; j++)
     {
-        auto argl             = args[l];
-        std::size_t nelements = argl.get_shape().elements();
-        visit_all(args.back(), argl)([&](auto output, auto input) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                auto* outptr      = output.data() + offsets[l];
-                const auto* inptr = input.data();
-                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(stream, nelements)(
-                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
+        auto&& arg            = args[j];
+        std::size_t nelements = arg.get_shape().elements();
+        auto offset           = offsets[j];
+        shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
+        hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
+            gs_launch(stream, nelements)([=](auto i) {
+                auto input_idx = input_shape.multi(i);
+                auto idx       = output.get_shape().index(input_idx);
+                output.data()[idx + offset] = input[input_idx];
             });
         });
     }
...
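The offsets passed to concat are the linear element offsets at which each input starts in the output. A host-side sketch of the same bookkeeping for 1-D concatenation (hypothetical shapes):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4, 5}};
    std::vector<int> output(5);
    std::size_t offset = 0;
    for(const auto& in : inputs)
    {
        for(std::size_t i = 0; i < in.size(); ++i)
            output[offset + i] = in[i]; // same "output index + offset" pattern
        offset += in.size();
    }
    for(auto v : output)
        std::cout << v << " "; // 1 2 3 4 5
    std::cout << "\n";
}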
#include <migraphx/gpu/device/div.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/erf.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void erf(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::erf(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -11,35 +11,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis)
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
 {
-    auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
-    visit_all(args.back(), args[0])([&](auto output, auto input) {
-        std::size_t nelements = output_shape.elements();
-        args[1].visit([&](auto indices) {
-            const auto* indices_ptr = device_cast(indices.data());
-            auto* out_ptr           = device_cast(output.data());
-            const auto* in_ptr      = device_cast(input.data());
-            auto& input_shape       = args[0].get_shape();
-            auto lens               = input_shape.lens();
-            lens[axis_index]        = args[1].get_shape().elements();
-            migraphx::shape out_comp_shape{output_shape.type(), lens};
-            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
-                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
-                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
-                gs_launch(stream, nelements)([=](auto ii) {
-                    auto in_idx        = desc_output.multi(ii);
-                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
-                    out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
-                });
+    auto axis_index   = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
+    auto& input_shape = arg1.get_shape();
+    auto lens         = input_shape.lens();
+    lens[axis_index]  = arg2.get_shape().elements();
+    shape out_comp_shape{result.get_shape().type(), lens};
+    std::size_t nelements = result.get_shape().elements();
+
+    visit_all(result, arg1)([&](auto output, auto input_v) {
+        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
+            arg2.visit([&](auto indices) {
+                const auto* indices_ptr = device_cast(indices.data());
+                auto* output_ptr        = device_cast(output.data());
+                gs_launch(stream, nelements, 256)([=](auto i) {
+                    auto idx        = out_comp.multi(i);
+                    idx[axis_index] = indices_ptr[idx[axis_index]];
+                    output_ptr[i]   = input[idx];
+                });
             });
         });
     });
-    return args.back();
+    return result;
 }
 
 } // namespace device
...
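The gather kernel rewrites only the axis coordinate through the indices tensor; in one dimension it degenerates to the familiar out[i] = in[indices[i]]. A standalone sketch:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> in    = {10.f, 20.f, 30.f, 40.f};
    std::vector<int> indices = {3, 0, 0};
    std::vector<float> out(indices.size());
    for(std::size_t i = 0; i < indices.size(); ++i)
        out[i] = in[indices[i]]; // idx[axis] = indices_ptr[idx[axis]] in 1-D
    for(auto v : out)
        std::cout << v << " "; // 40 10 10
    std::cout << "\n";
}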
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_array
{
T d[N];
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
{
T result = 0;
for(std::size_t i = 0; i < N; i++)
result += x[i] * d[i];
return result;
}
MIGRAPHX_DEVICE_CONSTEXPR T product() const
{
T result = 1;
for(std::size_t i = 0; i < N; i++)
result *= d[i];
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
{
hip_array result;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i];
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y)
{
hip_array result{};
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i];
return result;
}
};
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
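hip_array is a trivially-copyable std::array stand-in usable in device code; its dot() against a stride array is what maps a multi-index to a linear offset. A host-side sketch of that use (hypothetical shape {2, 3, 4}):

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t strides[] = {12, 4, 1}; // packed strides of lens {2, 3, 4}
    std::size_t idx[]     = {1, 2, 3};  // multi-index
    std::size_t linear    = 0;
    for(std::size_t i = 0; i < 3; ++i)
        linear += idx[i] * strides[i]; // hip_array::dot does exactly this
    std::cout << linear << "\n";       // 23, the last of 24 elements
}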
@@ -11,9 +11,33 @@ namespace device {
 
 struct index
 {
-    std::size_t global;
-    std::size_t local;
-    std::size_t group;
+    std::size_t global = 0;
+    std::size_t local  = 0;
+    std::size_t group  = 0;
+
+    __device__ std::size_t nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
+
+    __device__ std::size_t nlocal() const { return blockDim.x; } // NOLINT
+
+    template <class F>
+    __device__ void global_stride(std::size_t n, F f) const
+    {
+        const auto stride = nglobal();
+        for(std::size_t i = global; i < n; i += stride)
+        {
+            f(i);
+        }
+    }
+
+    template <class F>
+    __device__ void local_stride(std::size_t n, F f) const
+    {
+        const auto stride = nlocal();
+        for(std::size_t i = local; i < n; i += stride)
+        {
+            f(i);
+        }
+    }
 };
 
 template <class F>
@@ -35,18 +59,26 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
     };
 }
 
+template <class F>
+__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index idx) -> decltype(f(i, idx))
+{
+    return f(i, idx);
+}
+
+template <class F>
+__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))
+{
+    return f(i);
+}
+
 inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
 {
-    std::size_t groups  = 1 + n / local;
+    std::size_t groups  = (n + local - 1) / local;
     std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
 
     return [=](auto f) {
-        launch(stream, nglobal, local)([=](auto idx) {
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                f(i);
-            }
-        });
+        launch(stream, nglobal, local)(
            [=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); });
     };
 }
...
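The rounding change in gs_launch removes a spurious extra group when n divides local exactly; the grid-stride loop then covers n with min(256, groups) * local threads. A quick check of the arithmetic:

#include <algorithm>
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t n = 2048, local = 1024;
    std::size_t old_groups = 1 + n / local;           // 3: one group too many
    std::size_t new_groups = (n + local - 1) / local; // 2: exact
    std::size_t nglobal    = std::min<std::size_t>(256, new_groups) * local;
    std::cout << old_groups << " " << new_groups << " " << nglobal << "\n"; // 3 2 2048
}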
 #ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
 #define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
 
-#include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
-#include <migraphx/gpu/device/types.hpp>
+#include <migraphx/gpu/device/visit.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/array.hpp>
 #include <migraphx/config.hpp>
 
 namespace migraphx {
@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-template <class T>
-using vec4 = T __attribute__((ext_vector_type(4)));
-
-template <class T>
-__device__ __host__ vec4<T>* as_vec4(T* x)
-{
-    return reinterpret_cast<vec4<T>*>(x);
-}
-
-template <class T>
-__device__ __host__ T* as_pointer(vec4<T>* x)
-{
-    return reinterpret_cast<T*>(x);
-}
-
 template <class... Ts>
-auto pack_vec4(Ts... xs)
+auto pack(Ts... xs) __device__
 {
-    return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); };
+    return [=](auto f) { return f(xs...); };
 }
 
 template <class F, class... Arguments>
 auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-            auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
-                                            device_cast(inputs.data()))...);
-            hip_tensor_descriptor<ndim> out_desc(output_shape);
-            auto* outp = device_cast(output.data());
-            gs_launch(stream, output_shape.elements())([=](auto i) {
-                data([&](auto&&... ps) {
-                    auto outidx = out_desc.multi(i);
-                    outp[i]     = f(ps.second[ps.first.linear(outidx)]...);
-                });
-            });
-        });
+    std::size_t nelements = result.get_shape().elements();
+    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) {
+            auto idx  = output.get_shape().multi(i);
+            output[i] = f(inputs[idx]...);
+        });
     });
 }
 
-template <class F>
-void trinary_broadcast_vec_impl(hipStream_t stream,
-                                F f,
-                                const argument& result,
-                                const argument& arg1,
-                                const argument& arg2,
-                                const argument& arg3)
+template <class F, class... Arguments>
+void nary_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -73,56 +46,45 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* zp   = as_vec4(device_cast(input3.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = zp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx      = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b         = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> y   = yp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], y[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg, args...)(
+        [&](auto output, auto binput, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b    = bp[bidx];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
 }
 
-template <class F>
-void trinary_broadcast_impl(hipStream_t stream,
-                            F f,
-                            const argument& result,
-                            const argument& arg1,
-                            const argument& arg2,
-                            const argument& arg3)
+template <class F, class... Arguments>
+void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg3.get_shape();
+    const auto& b_shape      = barg.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -132,44 +94,39 @@ void trinary_broadcast_impl(hipStream_t stream,
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* zp   = device_cast(input3.data());
-        auto* outp = device_cast(output.data());
-
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
-
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
+        using type = typename decltype(output)::value_type;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPHX_DEVICE_SHARED type buffer[2048];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_len; i += nlocal)
             {
-                buffer[i] = zp[i];
+                buffer[i] = binput.data()[i];
             }
             __syncthreads();
             // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
+            for(size_t i = idx.global; i < nelements; i += nglobal)
             {
                 auto bidx = (i % bdim_next_stride) / bdim_stride;
                 auto b    = buffer[bidx];
-                type x    = xp[i];
-                type y    = yp[i];
-                outp[i]   = f(x, y, b);
+                output.data()[i] = f(inputs.data()[i]..., b);
             }
         });
     });
 }
 
-template <class F>
-void binary_broadcast_vec_impl(
-    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
 {
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
+    const auto& b_shape      = barg1.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -179,50 +136,54 @@ void binary_broadcast_vec_impl(
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
-
-        const std::size_t vec_size     = 4;
-        const std::size_t nlocal       = 1024;
-        const std::size_t nglobal      = 256 * nlocal;
-        const std::size_t n            = output.size() / vec_size;
-        const std::size_t bdim_vec_len = bdim_len / vec_size;
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
-            {
-                buffer[i] = yp[i];
-            }
-            __syncthreads();
-            auto* bp = as_pointer(buffer);
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx      = ((i * vec_size) % bdim_next_stride) / bdim_stride;
-                auto b         = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> out = outp[i];
-                for(std::size_t j = 0; j < vec_size; j++)
-                {
-                    out[j] = f(x[j], b);
-                }
-                outp[i] = out;
-            }
-        });
-    });
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
 }
 
-template <class F>
-void binary_broadcast_impl(
-    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
+template <class F, class... Arguments>
+void nary_double_broadcast_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
 {
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
     const auto& output_shape = result.get_shape();
-    const auto& b_shape      = arg2.get_shape();
+    const auto& b_shape      = barg1.get_shape();
     auto bdim =
         std::distance(b_shape.strides().begin(),
                       std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
@@ -232,48 +193,47 @@ void binary_broadcast_impl(
     auto bdim_stride      = output_shape.strides()[bdim];
     auto bdim_next_stride = bdim_stride * bdim_len;
 
-    visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = device_cast(input1.data());
-        auto* yp   = device_cast(input2.data());
-        auto* outp = device_cast(output.data());
-
-        const std::size_t nlocal  = 1024;
-        const std::size_t nglobal = 256 * nlocal;
-        const std::size_t n       = output.size();
-
-        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED type buffer[2048];
-            // Load bias into LDS
-            for(size_t i = idx.local; i < bdim_len; i += nlocal)
-            {
-                buffer[i] = yp[i];
-            }
-            __syncthreads();
-            // Process the data
-            for(size_t i = idx.global; i < n; i += nglobal)
-            {
-                auto bidx = (i % bdim_next_stride) / bdim_stride;
-                auto b    = buffer[bidx];
-                type x    = xp[i];
-                outp[i]   = f(x, b);
-            }
-        });
-    });
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type = typename decltype(output)::value_type;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i + bdim_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = (i % bdim_next_stride) / bdim_stride;
+                    auto b1   = buffer[bidx];
+                    auto b2   = buffer[bidx + bdim_len];
                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
+                }
+            });
        });
 }
 
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
         const std::size_t vec_size = 4;
-        auto data  = pack_vec4(device_cast(inputs.data())...);
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto data  = pack_vec<4>(device_cast(inputs.data())...);
+        auto* outp = as_vec<4>(device_cast(output.data()));
         gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
-            vec4<type> out = outp[i];
+            vec<type, 4> out = outp[i];
             data(
                 [&](auto... xs) {
                     for(std::size_t j = 0; j < vec_size; j++)
@@ -290,13 +250,9 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments...)
 
 template <class F, class... Arguments>
 void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        auto data  = pack(device_cast(inputs.data())...);
-        auto* outp = device_cast(output.data());
-        gs_launch(stream, output_shape.elements())(
-            [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
+    std::size_t nelements = result.get_shape().elements();
+    hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
     });
 }
 
@@ -313,12 +269,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
         nary_nonstandard_impl(stream, f, result, args...);
 }
 
-template <class F>
-void nary_impl(hipStream_t stream, F f, argument result)
-{
-    nary_standard_impl(stream, f, result);
-}
-
 template <class... Arguments>
 auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
 {
@@ -332,71 +282,114 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }
 
 template <class... Arguments>
-auto nary(hipStream_t stream, argument result, Arguments... args)
+bool broadcastable(bool& divisible_by_4,
+                   std::size_t max_size,
+                   const argument& result,
+                   const argument& barg,
+                   const Arguments&... args)
 {
-    return [=](auto f) { nary_impl(stream, f, result, args...); };
+    divisible_by_4 = false;
+    auto bshape    = barg.get_shape();
+    const bool standard =
+        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+    const bool same_shapes =
+        all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+    // TODO: Check result and args shape is the same
+    if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+    {
+        auto not_zero       = [](auto x) { return x != 0; };
+        const auto& strides = bshape.strides();
+        auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+        auto b_idx          = std::distance(strides.begin(), b_it);
+        auto b_len          = result.get_shape().lens()[b_idx];
+        auto b_stride       = result.get_shape().strides()[b_idx];
+        assert(bshape.lens()[b_idx] == b_len);
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
+        {
+            divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                             (front_args(args...).get_shape().elements() % 4 == 0);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+{
+    divisible_by_4 = false;
+    return false;
 }
inline auto // Nullary
nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) inline auto nary(hipStream_t stream, argument result)
{
return [=](auto f) { nary_standard_impl(stream, f, result); };
}
// Unary
inline auto nary(hipStream_t stream, argument result, argument arg)
{
return [=](auto f) { nary_impl(stream, f, result, arg); };
}
// Binary
inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same bool divisible_by_4 = false;
if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and if(broadcastable(divisible_by_4, 2048, result, barg, arg))
not arg2.get_shape().scalar())
{ {
auto not_zero = [](auto x) { return x != 0; }; if(divisible_by_4)
const auto& strides = arg2.get_shape().strides(); nary_broadcast_vec_impl(stream, f, result, barg, arg);
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero); else
auto b_idx = std::distance(strides.begin(), b_it); nary_broadcast_impl(stream, f, result, barg, arg);
auto b_len = result.get_shape().lens()[b_idx]; }
auto b_stride = result.get_shape().strides()[b_idx]; else
assert(arg2.get_shape().lens()[b_idx] == b_len); {
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero)) nary_impl(stream, f, result, arg, barg);
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
else
binary_broadcast_impl(stream, f, result, arg1, arg2);
return;
}
} }
nary_impl(stream, f, result, arg1, arg2);
}; };
} }
inline auto nary(hipStream_t stream, template <class... Arguments>
const argument& result, auto nary(hipStream_t stream, argument result, Arguments... args)
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same auto barg1 = back_args(args...);
if(arg1.get_shape().standard() and arg2.get_shape().standard() and bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
arg3.get_shape().broadcasted()) auto barg2 = back_args(args2...);
{ bool fallback2 =
auto not_zero = [](auto x) { return x != 0; }; barg2.get_shape() != barg1.get_shape() or not barg2.get_shape().broadcasted() or
const auto& strides = arg3.get_shape().strides(); pop_back_args(args2...)([&](auto&&... args3) {
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero); bool divisible_by_4 = false;
auto b_idx = std::distance(strides.begin(), b_it); if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
auto b_len = result.get_shape().lens()[b_idx]; {
auto b_stride = result.get_shape().strides()[b_idx]; if(divisible_by_4)
assert(arg3.get_shape().lens()[b_idx] == b_len); nary_double_broadcast_vec_impl(
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero)) stream, f, result, barg1, barg2, args3...);
else
nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
return false;
}
return true;
});
if(not fallback2)
return false;
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
{ {
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4) if(divisible_by_4)
trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3); nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
else else
trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3); nary_broadcast_impl(stream, f, result, barg1, args2...);
return; return false;
} }
} return true;
nary_impl(stream, f, result, arg1, arg2, arg3); });
if(fallback1)
nary_impl(stream, f, result, args...);
}; };
} }
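// Illustrative usage of the nary dispatcher above (a sketch, not part of this
// change): a pointwise device op such as add would be written roughly as
//
//   nary(stream, result, arg1, arg2)(
//       [](auto x, auto y) __device__ { return x + y; });
//
// The returned lambda inspects the argument shapes at dispatch time and routes
// to the broadcast, vectorized-broadcast, or generic implementation.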
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <algorithm>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
struct sum
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
};
struct id
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
};
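// Divides the accumulated value by the number of reduced elements; item_num
// must be set by the caller. Note the division truncates for integral T.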
struct mean
{
    std::size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x / item_num);
}
};
struct max
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x > y ? x : y;
}
};
struct min
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x < y ? x : y;
}
};
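// Initial values for reductions: these convert implicitly to the destination
// type's lowest/highest representable value (lowest is the identity for max,
// highest for min).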
struct lowest
{
template <class T>
operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::lowest());
}
};
struct highest
{
template <class T>
operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::max());
}
};
#ifdef MIGRAPHX_NO_DPP
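// Fallback block reduction used when DPP instructions are unavailable: each
// thread accumulates a strided slice of the input, then a shared-memory tree
// combine folds the per-thread partials into buffer[0].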
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
__syncthreads();
for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if(index + s < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
__syncthreads();
}
return buffer[0];
}
#else
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15: y = 0x142; break;
case 31: y = 0x143; break;
default: throw std::runtime_error("Unknown bcast");
}
return y;
}
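// Cross-lane move using AMD's DPP (data parallel primitives) modifiers: the
// value is copied between lanes of a wavefront according to DppCtrl. Types
// wider than 32 bits are moved as multiple 32-bit registers.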
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
static const std::size_t n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output{};
type input{};
// cppcheck-suppress unreadVariable
input.data = x;
for(std::size_t i = 0; i < n; i++)
{
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
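// Wavefront-wide reduction: the successive row_shr shifts fold each 16-lane
// row, and the row_bcast steps combine rows so that lane 63 ends up holding
// the reduction of all 64 lanes.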
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
T out{};
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
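// Hand-written specialization of dpp_reduce for float sums; the s_nop
// instructions provide the wait states required between dependent DPP ops.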
__device__ inline void dpp_reduce(float& x, sum)
{
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
x = 1;
#else
__asm__ volatile("s_nop 4\n"
"v_add_f32 %0 %0 %0 row_shr:1\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:2\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:4 bank_mask:0xe\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc\n"
"s_nop 1\n"
: "=v"(x)
: "0"(x));
#endif
}
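// Block reduction built on dpp_reduce: lane 63 of each wavefront writes its
// wavefront-level result to LDS, then every thread folds the per-wavefront
// values to produce the block-wide result.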
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / 64;
if((idx.local % 64) == 63)
{
buffer[ldsidx] = x;
}
__syncthreads();
type y = init;
for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
{
y = op(y, buffer[i]);
}
return y;
}
#endif
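// Pick a power-of-two block size, at least one wavefront (64), that covers n
// without exceeding max_block_size.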
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
    std::size_t block_size = 64;
while(block_size < max_block_size and block_size < n)
block_size *= 2;
return block_size;
}
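// General case: reduce_slice describes the reduced dimensions; one block is
// launched per output element and walks that slice through multi-indices
// offset by the output element's base index.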
template <class Op, class T, class Input, class Output>
void reduce_multi_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
const shape& reduce_slice)
{
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
auto base_idx = output.get_shape().multi(out_idx);
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
auto reduce_idx = reduce_shape.multi(j);
return read_input(input[reduce_idx + base_idx]);
});
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
});
}
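// Fast path for standard (packed) layouts reducing the innermost dimension:
// each reduction reads a contiguous run of relements, so plain linear offsets
// are used instead of multi-indices.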
template <class Op, class T, class Input, class Output>
void reduce_standard_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
std::size_t relements)
{
hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements();
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]);
});
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
});
}
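// Entry point: choose the contiguous fast path when both shapes are standard
// and only the trailing dimension differs; otherwise derive a reduce_slice
// shape (1 where the lengths match, the input length where they are reduced).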
template <class Op, class T, class Input, class Output>
void reduce(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output)
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{
reduce_standard_impl(
stream, result, arg, op, init, read_input, read_output, input_shape.lens().back());
}
else
{
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens};
reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
}
}
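// Illustrative callers (a sketch of how reduction ops are expected to use this
// header; the exact op implementations are not part of this diff):
//   reduce(stream, result, arg, sum{}, 0, id{}, id{});        // reduce_sum
//   reduce(stream, result, arg, max{}, lowest{}, id{}, id{}); // reduce_max
//   reduce(stream, result, arg, sum{}, 0, id{}, mean{n});     // mean over n elements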
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif