Commit a392d84c authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bugs_for_bert

parents 5a264889 0628e570
@@ -14,6 +14,7 @@ add_library(migraphx
     eliminate_pad.cpp
     fwd_conv_batchnorm_rewrite.cpp
     rewrite_rnn.cpp
+    rewrite_pooling.cpp
     env.cpp
     generate.cpp
     instruction.cpp
...
@@ -30,23 +30,29 @@ struct binary : op_name<Derived>
         argument result{output_shape};
         auto s1 = args[0].get_shape();
         auto s2 = args[1].get_shape();
-        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
-            {
+        if(s1 == s2 and s1.packed())
+        {
+            shape std_shape{s1.type(), s1.lens()};
+            argument std_result{std_shape, result.data()};
+            argument std_arg0{std_shape, args[0].data()};
+            argument std_arg1{std_shape, args[1].data()};
+            visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
                 std::transform(input1.begin(),
                                input1.end(),
                                input2.begin(),
                                output.begin(),
                                static_cast<const Derived&>(*this).apply());
-            }
-            else
-            {
+            });
+        }
+        else
+        {
+            visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
                 shape_for_each(output.get_shape(), [&](const auto& idx) {
                     output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
                         input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
                 });
-            }
-        });
+            });
+        }
         return result;
     }
...
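The fast path in this hunk works because a packed tensor occupies one contiguous block of memory: aliasing it with a standard shape of the same type and lengths lets the operator run as a single flat std::transform instead of per-element index arithmetic. A minimal standalone sketch of that flat traversal (plain std::vector stands in for migraphx::argument, and the lambda for the Derived apply(); both are illustrative assumptions, not code from this commit):

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    // Two tensors of identical packed shape: the elementwise operator can
    // ignore the dimensions entirely and walk both flat buffers once.
    std::vector<float> a{1, 2, 3, 4, 5, 6}; // e.g. a packed 2x3 tensor
    std::vector<float> b{6, 5, 4, 3, 2, 1}; // same shape and layout
    std::vector<float> out(a.size());
    std::transform(a.begin(), a.end(), b.begin(), out.begin(),
                   [](float x, float y) { return x + y; }); // stand-in for apply()
    assert(out.front() == 7 and out.back() == 7);
    return 0;
}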
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP

#include <array>
#include <functional>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct capture
{
    std::size_t ins_index;
    std::function<void(std::size_t ins_index, std::vector<argument>)> f{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.ins_index, "ins_index"));
    }

    std::string name() const { return "capture"; }
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
    argument compute(const shape&, std::vector<argument> args) const
    {
        if(f)
        {
            f(ins_index, args);
        }
        else
        {
            MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
        }

        return args.front();
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
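The capture operator is shape- and value-preserving: compute_shape returns the input shape and compute forwards args.front(), so inserting it between a producer and its consumer only exposes the intermediate argument to the callback. A hedged fragment showing how a callback would be attached (illustrative only, not code from this commit):

// Record the captured instruction index; a real callback would inspect the
// argument values, e.g. to derive quantization scales.
migraphx::op::capture cap{0, [](std::size_t ins_index, std::vector<migraphx::argument> args) {
    (void)args; // the producer's output values land here at evaluation time
    std::cout << "captured instruction " << ins_index << std::endl;
}};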
@@ -59,7 +59,9 @@ struct reshape
         shape s{inputs.front().type(), rdims};
         if(s.elements() != inputs.front().elements())
-            MIGRAPHX_THROW("Wrong number of elements for reshape");
+            MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
+                           std::to_string(s.elements()) + " elements whereas the input has " +
+                           std::to_string(inputs.front().elements()));
         return s;
     }
     argument compute(shape output_shape, std::vector<argument> args) const
...
@@ -27,26 +27,34 @@ struct unary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        result.visit([&](auto output) {
-            args[0].visit([&](auto input) {
-                if(input.get_shape().packed())
-                {
-                    std::transform(input.begin(),
-                                   input.end(),
-                                   output.begin(),
-                                   static_cast<const Derived&>(*this).apply());
-                }
-                else
-                {
-                    shape_for_each(output.get_shape(), [&](const auto& idx) {
-                        output(idx.begin(), idx.end()) =
-                            static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
-                    });
-                }
-            });
-        });
+        auto in_shape = args[0].get_shape();
+        if(in_shape.packed())
+        {
+            shape std_in_shape{in_shape.type(), in_shape.lens()};
+            shape std_out_shape{output_shape.type(), output_shape.lens()};
+            argument arg_in{std_in_shape, args[0].data()};
+            argument arg_out{std_out_shape, result.data()};
+            arg_out.visit([&](auto output) {
+                arg_in.visit([&](auto input) {
+                    std::transform(input.begin(),
+                                   input.end(),
+                                   output.begin(),
+                                   static_cast<const Derived&>(*this).apply());
+                });
+            });
+        }
+        else
+        {
+            result.visit([&](auto output) {
+                args[0].visit([&](auto input) {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
+                            input(idx.begin(), idx.end()));
+                    });
+                });
+            });
+        }
         return result;
     }
...
@@ -13,6 +13,7 @@
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
+#include <migraphx/op/capture.hpp>
 #include <migraphx/op/clip.hpp>
 #include <migraphx/op/common.hpp>
 #include <migraphx/op/concat.hpp>
...
@@ -126,6 +126,9 @@ struct program
     friend bool operator==(const program& x, const program& y);
     friend bool operator!=(const program& x, const program& y) { return !(x == y); }

+    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
+        std::make_shared<std::vector<std::pair<float, float>>>();
+
     private:
     void assign(const program& p);
...
@@ -15,6 +15,14 @@ struct program;
 void quantize(program& prog, const std::vector<std::string>& ins_names);
 void quantize(program& prog);

+// insert the capture operator for the inputs of each operator to be
+// quantized to int8
+void capture_arguments(program& prog,
+                       const std::vector<std::string>& ins_names,
+                       const std::function<void(std::size_t, std::vector<argument>)>& func);
+
+void capture_arguments(program& prog, const std::vector<std::string>& ins_names);
+
+void capture_arguments(program& prog);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
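Based on these declarations, the intended calibration flow is: insert the capture operators first, then evaluate the program on representative data so each callback sees real argument values. A hedged sketch of that sequence (build_model and calibration_inputs are assumed placeholders, not part of this diff):

migraphx::program prog = build_model(); // assumed helper that parses or builds a model
migraphx::capture_arguments(prog, {"dot", "convolution"});
// Every dot/convolution input is now routed through a capture instruction,
// so evaluating the program invokes the registered callback with the
// observed argument values.
prog.eval(calibration_inputs); // assumed representative parameter_map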
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP

#include <string>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite pooling to reduce_mean
 */
struct rewrite_pooling
{
    std::string name() const { return "rewrite_pooling"; }
    void apply(program& prog) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -132,7 +132,11 @@ struct tensor_view
         return m_data + this->size();
     }

-    std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); }
+    template <class U = T>
+    std::vector<U> to_vector() const
+    {
+        return std::vector<U>(this->begin(), this->end());
+    }

     friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
     {
...
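The defaulted template parameter keeps existing call sites source-compatible while letting callers convert the element type during the copy. A hedged usage sketch (tv is an assumed tensor_view<half>):

std::vector<half>  same = tv.to_vector();        // unchanged behavior
std::vector<float> wide = tv.to_vector<float>(); // widen while copying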
@@ -113,6 +113,7 @@ void program::assign(const program& p)
         impl->instructions.clear();
     }
     impl->ctx = p.impl->ctx;
+    int8_quant_params = p.int8_quant_params;

     std::unordered_map<instruction_ref, instruction_ref> ins_map;
     for(auto ins : iterator_for(p))
...
@@ -156,6 +156,7 @@ PYBIND11_MODULE(migraphx, m)
     py::class_<migraphx::target>(m, "target");

     py::class_<migraphx::program>(m, "program")
+        .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
         .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
         .def("get_shape", &migraphx::program::get_shape)
         .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
@@ -186,6 +187,11 @@ PYBIND11_MODULE(migraphx, m)
         migraphx::quantize(p, ins_names);
     });
     m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
+    m.def("capture_arguments",
+          [](migraphx::program& p, const std::vector<std::string>& ins_names) {
+              migraphx::capture_arguments(p, ins_names);
+          });
+    m.def("capture_arguments", [](migraphx::program& p) { migraphx::capture_arguments(p); });

 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
...
@@ -3,32 +3,53 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/op/convert.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/add.hpp>
+#include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/capture.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/multibroadcast.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <utility>
+#include <iomanip>
+#include <fstream>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-instruction_ref insert_fp16(program& prog,
-                            instruction_ref& ins,
-                            shape::type_t type,
-                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
+instruction_ref insert_quant_ins(program& prog,
+                                 instruction_ref& ins,
+                                 shape::type_t type,
+                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins)
 {
-    if(map_fp16.count(ins) > 0)
+    if(map_ins.count(ins) > 0)
     {
-        return map_fp16[ins];
+        return map_ins[ins];
     }
+
+    if(ins->name() == "undefined")
+    {
+        return ins;
+    }
+
     assert(ins->get_shape().type() == shape::float_type ||
-           ins->get_shape().type() == shape::double_type);
-    instruction_ref ins_fp16{};
-    ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
-    map_fp16[ins] = ins_fp16;
+           ins->get_shape().type() == shape::double_type ||
+           ins->get_shape().type() == shape::int32_type);
+    instruction_ref quant_ins{};
+    quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
+    map_ins[ins] = quant_ins;

-    return ins_fp16;
+    return quant_ins;
 }

+// This function converts any instruction specified in the input
+// from double or float to float16 by inserting a convert operator.
+// The conversion can overflow, but that is very rare in deep
+// learning, so we simply truncate the input to get the fp16 value.
 void quantize(program& prog, const std::vector<std::string>& ins_names)
 {
     std::unordered_map<instruction_ref, instruction_ref> map_fp16;
@@ -60,7 +81,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
             }
             else
             {
-                input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
+                input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
             }
             converted_inputs.push_back(input_fp16);
         }
@@ -79,13 +100,6 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
         auto op        = ins->get_operator();
         auto ins_shape = compute_shape(op, converted_inputs);
         if(ins_shape.type() != orig_type)
-        {
-            // insert another convert instruction to convert it back
-            if(ins == std::prev(prog.end()))
-            {
-                prog.add_instruction(op::convert{orig_type}, ins);
-            }
-            else
         {
             // check the dead code case to avoid assert
             bool output_empty = ins->outputs().empty();
@@ -96,7 +110,6 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
                 prog.replace_instruction(ins, ins_orig_type);
             }
         }
-        }

         prog.replace_instruction(ins, op, converted_inputs);
     }
@@ -104,5 +117,80 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)

 void quantize(program& prog) { quantize(prog, {"all"}); }

+// For each input argument of the instructions to be quantized, insert a
+// capture operator to compute the scale and shift
+void capture_arguments(program& prog,
+                       const std::vector<std::string>& ins_names,
+                       const std::function<void(std::size_t, std::vector<argument>)>& func)
+{
+    size_t num_quant_params = 0;
+    // the int8 quantization only supports dot and convolution
+    std::vector<std::string> op_names = {"dot", "convolution"};
+    if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
+           return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
+       }))
+    {
+        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
+    }
+
+    std::unordered_map<instruction_ref, instruction_ref> ins_map;
+    for(auto ins : iterator_for(prog))
+    {
+        if(not contains(ins_names, ins->name()))
+        {
+            continue;
+        }
+
+        auto inputs = ins->inputs();
+        std::vector<instruction_ref> new_args;
+        for(auto input : inputs)
+        {
+            instruction_ref new_ins{};
+            if(ins_map.count(input) > 0)
+            {
+                new_ins = ins_map[input];
+            }
+            else
+            {
+                new_ins = prog.insert_instruction(
+                    std::next(input), op::capture{num_quant_params++, func}, input);
+                ins_map[input] = new_ins;
+            }
+            new_args.push_back(new_ins);
+        }
+        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
+    }
+
+    // set one pair of parameters for each argument
+    prog.int8_quant_params->resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
+}
+
+void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
+{
+    auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
+        std::pair<float, float> param_pair{1.0f, 0.0f};
+
+        // scale and shift are needed only for the int8 type, and we do not
+        // consider shift, so it is set to 0
+        std::vector<float> vec_val;
+        args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
+        auto max_val     = *std::max_element(vec_val.begin(), vec_val.end());
+        auto min_val     = *std::min_element(vec_val.begin(), vec_val.end());
+        auto max_abs     = std::max(std::fabs(max_val), std::fabs(min_val));
+        param_pair.first = 127.0f / max_abs;
+
+        (*prog.int8_quant_params)[ins_index] = param_pair;
+    };
+
+    capture_arguments(prog, ins_names, calc_quant_params);
+}
+
+void capture_arguments(program& prog)
+{
+    std::vector<std::string> ins_names = {"dot", "convolution"};
+    capture_arguments(prog, ins_names);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
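The default callback's arithmetic maps the largest observed magnitude onto the symmetric edge of the int8 range. A self-contained numeric check of that formula (values chosen so the float math is exact):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    // Same computation as calc_quant_params above, on a toy tensor.
    std::vector<float> vals{-63.5f, 1.0f, 3.25f};
    float max_val = *std::max_element(vals.begin(), vals.end());
    float min_val = *std::min_element(vals.begin(), vals.end());
    float max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
    float scale   = 127.0f / max_abs;
    assert(scale == 2.0f); // -63.5 * 2 = -127, the symmetric int8 limit
    return 0;
}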
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_pooling::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "pooling")
            continue;
        if(ins->get_shape().lens().size() != 4)
            continue;
        if(ins->inputs().empty())
            continue;
        auto&& s  = ins->inputs().front()->get_shape();
        auto&& op = any_cast<op::pooling>(ins->get_operator());
        if(op.mode != "average")
            continue;
        if(op.padding[0] != 0 and op.padding[1] != 0)
            continue;
        if(op.stride[0] != 1 and op.stride[1] != 1)
            continue;
        if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
            continue;
        std::int64_t n = s.lens()[0];
        std::int64_t c = s.lens()[1];
        auto reshape =
            prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
        auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
        prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -261,7 +261,8 @@ struct cpu_quant_convolution
                         const auto in_ch = group_id * wei_c + k;
                         if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
                         {
-                            acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
+                            acc += static_cast<int32_t>(input(o, in_ch, in_x, in_y)) *
+                                   weights(w, k, x, y);
                         }
                     });
                     output(o, w, i, j) = acc;
@@ -576,8 +577,7 @@ struct cpu_quant_gemm
     }

     // 2 input arguments
-    int32_t beta = 0;
-    migemm(result, arg_0, arg_1, op.alpha, beta);
+    migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
     return result;
 }
...
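The explicit static_cast makes the accumulation width unambiguous: the int8 product is formed and summed as int32. (In standard C++ the int8_t operands would already promote to int before the multiply, so the cast chiefly documents intent and silences conversion warnings.) A standalone illustration of why the product must be held in a wider type than the operands:

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // Accumulate an int8 dot product in int32, widening each operand
    // explicitly as the convolution inner loop above now does.
    std::vector<int8_t> x{127, 127, 127};
    std::vector<int8_t> w{127, 127, 127};
    int32_t acc = 0;
    for(std::size_t i = 0; i < x.size(); ++i)
        acc += static_cast<int32_t>(x[i]) * w[i];
    assert(acc == 3 * 127 * 127); // 48387, far outside the int8 range
    return 0;
}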
@@ -82,6 +82,7 @@ add_library(migraphx_gpu
     elu.cpp
     pad.cpp
     gather.cpp
+    convert.cpp
     lrn.cpp
     schedule_model.cpp
     adjust_allocation.cpp
...
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/convert.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
    // The trailing input is the preallocated output buffer; drop it
    // before validating the real inputs.
    inputs.pop_back();
    check_shapes{inputs}.packed();
    return op.compute_shape(inputs);
}

argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    device::convert(ctx.get_stream().get(), args[1], args[0]);
    return args[1];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -16,6 +16,12 @@ struct hip_array
     MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
     MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }

+    MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
+    MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
+
+    MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[N - 1]; }
+    MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[N - 1]; }
+
     MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
     MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
...
@@ -209,28 +209,15 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
 }

 template <class Op, class T, class Input, class Output>
-void reduce(hipStream_t stream,
-            const argument& result,
-            const argument& arg,
-            Op op,
-            T init,
-            Input read_input,
-            Output read_output)
+void reduce_multi_impl(hipStream_t stream,
+                       const argument& result,
+                       const argument& arg,
+                       Op op,
+                       T init,
+                       Input read_input,
+                       Output read_output,
+                       const shape& reduce_slice)
 {
-    auto&& output_shape = result.get_shape();
-    auto&& input_shape  = arg.get_shape();
-    std::vector<std::size_t> reduce_lens;
-    std::transform(output_shape.lens().begin(),
-                   output_shape.lens().end(),
-                   input_shape.lens().begin(),
-                   std::back_inserter(reduce_lens),
-                   [](auto x, auto y) -> std::size_t {
-                       if(x == y)
-                           return 1;
-                       else
-                           return y;
-                   });
-    shape reduce_slice{output_shape.type(), reduce_lens};
     hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
         auto nelements = result.get_shape().elements();
         auto relements = reduce_slice.elements();
@@ -250,6 +237,83 @@ void reduce_multi_impl(hipStream_t stream,
     });
 }

+template <class Op, class T, class Input, class Output>
+void reduce_standard_impl(hipStream_t stream,
+                          const argument& result,
+                          const argument& arg,
+                          Op op,
+                          T init,
+                          Input read_input,
+                          Output read_output,
+                          std::size_t relements,
+                          std::size_t stride)
+{
+    hip_visit_all(result, arg)([&](auto output, auto input) {
+        auto nelements = result.get_shape().elements();
+
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(relements, max_block_size);
+        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
+            const auto out_idx  = i / block_size;
+            const auto base_idx = out_idx * stride;
+            auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
+                return read_input(input.data()[base_idx + j]);
+            });
+            if(idx.local == 0)
+                output.data()[out_idx] = read_output(r);
+        });
+    });
+}
+
+template <class Op, class T, class Input, class Output>
+void reduce(hipStream_t stream,
+            const argument& result,
+            const argument& arg,
+            Op op,
+            T init,
+            Input read_input,
+            Output read_output)
+{
+    auto&& output_shape = result.get_shape();
+    auto&& input_shape  = arg.get_shape();
+    if(input_shape.standard() and output_shape.standard() and
+       output_shape.lens().back() != input_shape.lens().back() and
+       std::equal(output_shape.lens().begin(),
+                  std::prev(output_shape.lens().end()),
+                  input_shape.lens().begin()))
+    {
+        std::size_t stride = std::accumulate(input_shape.strides().begin(),
+                                             input_shape.strides().end(),
+                                             1,
+                                             std::multiplies<size_t>());
+        reduce_standard_impl(stream,
+                             result,
+                             arg,
+                             op,
+                             init,
+                             read_input,
+                             read_output,
+                             input_shape.lens().back(),
+                             stride);
+    }
+    else
+    {
+        std::vector<std::size_t> reduce_lens;
+        std::transform(output_shape.lens().begin(),
+                       output_shape.lens().end(),
+                       input_shape.lens().begin(),
+                       std::back_inserter(reduce_lens),
+                       [](auto x, auto y) -> std::size_t {
+                           if(x == y)
+                               return 1;
+                           else
+                               return y;
+                       });
+
+        shape reduce_slice{output_shape.type(), reduce_lens};
+        reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
+    }
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
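The new dispatcher separates two cases: when both shapes are standard and only the innermost axis is reduced, every output element reads a contiguous run of the input, so the kernel can index with base_idx = out_idx * stride instead of general multi-dimensional arithmetic. A host-side analogue of that indexing (plain C++, no HIP; purely illustrative):

#include <cassert>
#include <numeric>
#include <vector>

int main()
{
    // Reduce the last axis of a standard (row-major) 2x3 tensor: output
    // element i sums the contiguous run input[i*stride, i*stride + relements).
    std::vector<int> input{1, 2, 3, 4, 5, 6};
    const std::size_t relements = 3; // innermost axis length
    const std::size_t stride    = 3; // distance between consecutive rows
    std::vector<int> output(input.size() / relements);
    for(std::size_t i = 0; i < output.size(); ++i)
        output[i] = std::accumulate(input.begin() + i * stride,
                                    input.begin() + i * stride + relements, 0);
    assert(output[0] == 6 and output[1] == 15);
    return 0;
}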
@@ -69,6 +69,8 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
     });
 }

+void sync_stream(hipStream_t stream) { hipStreamSynchronize(stream); }
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...