Commit d6b4ae77 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge optimization to print flops branch

parents bdf91961 abe2a889
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// ONNX parser for the ScatterND operator. The optional "reduction" attribute
// selects which scatternd variant is emitted: "add" -> scatternd_add,
// "mul" -> scatternd_mul, anything else (including no attribute) falls back
// to plain assignment via scatternd_none.
struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        // Default when the reduction attribute is absent or unrecognized.
        std::string op_name = "scatternd_none";
        auto it             = info.attributes.find("reduction");
        if(it != info.attributes.end())
        {
            const auto& reduction = it->second.s();
            if(reduction == "add")
                op_name = "scatternd_add";
            else if(reduction == "mul")
                op_name = "scatternd_mul";
        }
        return info.add_instruction(migraphx::make_op(op_name), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -947,9 +947,9 @@ void program::perf_report(std::ostream& os, ...@@ -947,9 +947,9 @@ void program::perf_report(std::ostream& os,
} }
void program::debug_print() const { std::cout << *this << std::endl; } void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const void program::debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& ins_names) const
{ {
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) { if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return is_end(pp.second.end(), ins); return is_end(pp.second.end(), ins);
})) }))
...@@ -965,14 +965,10 @@ void program::debug_print(instruction_ref ins) const ...@@ -965,14 +965,10 @@ void program::debug_print(instruction_ref ins) const
return; return;
} }
std::stringstream ss; if(contains(ins_names, ins))
this->print(names, [&](auto x, auto ins_names) { {
if(x == ins) instruction::print(std::cout, ins, ins_names);
{ }
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
});
} }
void program::debug_print(std::ostream& os, void program::debug_print(std::ostream& os,
...@@ -1078,11 +1074,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera ...@@ -1078,11 +1074,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) { std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name(); return mod->name();
}); });
transform_if(m.begin(), transform_if(
m.end(), m.begin(),
out, m.end(),
[&](auto&& pp) { return not contains(used, pp.first); }, out,
[](auto&& pp) { return &pp.second; }); [&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
} }
std::vector<const module*> program::get_modules() const std::vector<const module*> program::get_modules() const
......
...@@ -303,86 +303,90 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -303,86 +303,90 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name); .def("name", &migraphx::operation::name);
m.def("parse_tf", m.def(
[](const std::string& filename, "parse_tf",
bool is_nhwc, [](const std::string& filename,
unsigned int batch_size, bool is_nhwc,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, unsigned int batch_size,
std::vector<std::string> output_names) { std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
return migraphx::parse_tf( std::vector<std::string> output_names) {
filename, return migraphx::parse_tf(
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names}); filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
}, },
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true, py::arg("is_nhwc") = true,
py::arg("batch_size") = 1, py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>()); py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx", m.def(
[](const std::string& filename, "parse_onnx",
unsigned int default_dim_value, [](const std::string& filename,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, unsigned int default_dim_value,
bool skip_unknown_operators, std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool print_program_on_error, bool skip_unknown_operators,
int64_t max_loop_iterations) { bool print_program_on_error,
migraphx::onnx_options options; int64_t max_loop_iterations) {
options.default_dim_value = default_dim_value; migraphx::onnx_options options;
options.map_input_dims = map_input_dims; options.default_dim_value = default_dim_value;
options.skip_unknown_operators = skip_unknown_operators; options.map_input_dims = map_input_dims;
options.print_program_on_error = print_program_on_error; options.skip_unknown_operators = skip_unknown_operators;
options.max_loop_iterations = max_loop_iterations; options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx(filename, options); options.max_loop_iterations = max_loop_iterations;
}, return migraphx::parse_onnx(filename, options);
"Parse onnx file", },
py::arg("filename"), "Parse onnx file",
py::arg("default_dim_value") = 1, py::arg("filename"),
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("default_dim_value") = 1,
py::arg("skip_unknown_operators") = false, py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("print_program_on_error") = false, py::arg("skip_unknown_operators") = false,
py::arg("max_loop_iterations") = 10); py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer, m.def(
unsigned int default_dim_value, "parse_onnx_buffer",
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, [](const std::string& onnx_buffer,
bool skip_unknown_operators, unsigned int default_dim_value,
bool print_program_on_error) { std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
migraphx::onnx_options options; bool skip_unknown_operators,
options.default_dim_value = default_dim_value; bool print_program_on_error) {
options.map_input_dims = map_input_dims; migraphx::onnx_options options;
options.skip_unknown_operators = skip_unknown_operators; options.default_dim_value = default_dim_value;
options.print_program_on_error = print_program_on_error; options.map_input_dims = map_input_dims;
return migraphx::parse_onnx_buffer(onnx_buffer, options); options.skip_unknown_operators = skip_unknown_operators;
}, options.print_program_on_error = print_program_on_error;
"Parse onnx file", return migraphx::parse_onnx_buffer(onnx_buffer, options);
py::arg("filename"), },
py::arg("default_dim_value") = 1, "Parse onnx file",
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("filename"),
py::arg("skip_unknown_operators") = false, py::arg("default_dim_value") = 1,
py::arg("print_program_on_error") = false); py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
m.def("load", py::arg("print_program_on_error") = false);
[](const std::string& name, const std::string& format) {
migraphx::file_options options; m.def(
options.format = format; "load",
return migraphx::load(name, options); [](const std::string& name, const std::string& format) {
}, migraphx::file_options options;
"Load MIGraphX program", options.format = format;
py::arg("filename"), return migraphx::load(name, options);
py::arg("format") = "msgpack"); },
"Load MIGraphX program",
m.def("save", py::arg("filename"),
[](const migraphx::program& p, const std::string& name, const std::string& format) { py::arg("format") = "msgpack");
migraphx::file_options options;
options.format = format; m.def(
return migraphx::save(p, name, options); "save",
}, [](const migraphx::program& p, const std::string& name, const std::string& format) {
"Save MIGraphX program", migraphx::file_options options;
py::arg("p"), options.format = format;
py::arg("filename"), return migraphx::save(p, name, options);
py::arg("format") = "msgpack"); },
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target); m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
......
...@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n) std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{ {
while(reduce_dim(shapes, n) and n < shapes.size()) while(reduce_dim(shapes, n) and n < shapes.size()) {}
{
}
return n + 1; return n + 1;
} }
......
...@@ -86,6 +86,8 @@ struct shape_impl ...@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate( return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
}; };
const std::vector<shape::type_t>& shape::types() const std::vector<shape::type_t>& shape::types()
...@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l, const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm) const std::vector<int64_t>& perm)
...@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const ...@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
shape shape::with_type(type_t t) const
{
auto c = impl->copy();
c->m_type = t;
return {c};
}
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
......
...@@ -335,7 +335,6 @@ struct find_concat_op ...@@ -335,7 +335,6 @@ struct find_concat_op
} }
auto y = p.insert_instruction(ins, op, concats); auto y = p.insert_instruction(ins, op, concats);
return {y}; return {y};
}; };
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
......
...@@ -327,7 +327,6 @@ struct find_nested_concat ...@@ -327,7 +327,6 @@ struct find_nested_concat
else else
args.push_back(i); args.push_back(i);
} }
})(ins->inputs()); })(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args); p.replace_instruction(ins, ins->get_operator(), args);
} }
......
...@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs> ...@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs) bool is_vectorizable(const Xs&... xs)
{ {
return all_of({xs...}, [](const auto& s) { return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0) if(s.standard() and (s.lens().back() % N) == 0)
return true; return true;
if(s.broadcasted()) if(s.broadcasted())
......
...@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen) ...@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif() endif()
include(Embed) include(Embed)
file(GLOB KERNEL_FILES file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES}) add_embed_library(migraphx_kernels ${KERNEL_FILES})
...@@ -133,6 +133,7 @@ add_library(migraphx_gpu ...@@ -133,6 +133,7 @@ add_library(migraphx_gpu
compile_hip_code_object.cpp compile_hip_code_object.cpp
compile_pointwise.cpp compile_pointwise.cpp
compile_roialign.cpp compile_roialign.cpp
compile_scatternd.cpp
concat.cpp concat.cpp
convert.cpp convert.cpp
convolution.cpp convolution.cpp
......
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// HIP source for the ScatterND kernel, embedded as a raw string and compiled
// at runtime via compile_hip_code_object (see compile_scatternd). The
// REDUCTION token is not defined here: it is injected through a -D compiler
// flag as "assign_<reduction>{}" when the kernel is compiled. The trailing
// `int main() {}` appears to be required so the embedded translation unit
// links as a complete program — TODO confirm against compile_hip_code_object.
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., REDUCTION);
});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
// Builds a compiled ScatterND GPU operation. io_shapes holds the kernel's
// tensor shapes (last entry is the output); `reduction` names the update
// rule and is spliced into the kernel source as the REDUCTION functor
// (assign_<reduction>{}) via a -D flag.
operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
    hip_compile_options options;
    options.inputs         = io_shapes;
    options.virtual_inputs = io_shapes;
    options.output         = io_shapes.back();
    options.kernel_name    = "scatternd_kernel";
    options.local          = 1024;
    // Launch size is driven by io_shapes[1] — presumably the updates/indices
    // tensor, one work-item per element; TODO confirm ordering with callers.
    options.global = compute_global(io_shapes.at(1).elements(), options.local);
    options.params += " -DREDUCTION=assign_" + reduction + "{}";
    return compile_hip_code_object(scatternd_kernel, options);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/add.hpp> #include <migraphx/gpu/device/add.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Heuristic for the BERT-style add pattern: the first shape is rank 2 and
// the second input is broadcast along its leading dimension (stride 0 on
// axis 0). Used to select the fused half2 kernel below.
static bool is_bert(const std::vector<shape>& ss)
{
auto n_dim = ss.front().lens().size();
if(n_dim == 2)
{
auto stride = ss.at(1).strides();
return (stride[0] == 0);
}
return false;
}
// Packed-half2 elementwise add with b broadcast along the leading dimension:
//   r[i] = a[i] + b[i % n_dim]  for i in [0, n)
// All pointers address __half2 data; n counts half2 elements.
__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i >= n)
        return;
    auto* pa = reinterpret_cast<__half2*>(a);
    auto* pb = reinterpret_cast<__half2*>(b);
    auto* pr = reinterpret_cast<__half2*>(r);
    pr[i]    = __hadd2(pa[i], pb[i % n_dim]);
}
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; }); auto sr = result.get_shape();
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
}
} }
void add(hipStream_t stream, void add(hipStream_t stream,
......
#include <migraphx/gpu/device/contiguous.hpp> #include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <migraphx/permutation.hpp>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Copies a rank-4 half tensor between layouts: grid dims (x, y, z) index the
// three outer dimensions, threadIdx.x indexes the innermost one, which is
// assumed contiguous (stride 1) in both tensors. os1..os3 / is1..is3 are the
// output/input strides of the outer dimensions.
// NOTE(review): the only call site visible in this file is commented out, so
// this kernel appears unused at present — confirm before relying on it.
__global__ void
cont_kernel(void* in, void* out, int os1, int os2, int os3, int is1, int is2, int is3)
{
int i1 = blockIdx.x;
int i2 = blockIdx.y;
int i3 = blockIdx.z;
int i4 = threadIdx.x;
__half* in_ptr = reinterpret_cast<__half*>(in);
__half* out_ptr = reinterpret_cast<__half*>(out);
int out_idx = i1 * os1 + i2 * os2 + i3 * os3 + i4;
int in_idx = i1 * is1 + i2 * is2 + i3 * is3 + i4;
out_ptr[out_idx] = in_ptr[in_idx];
}
void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg) void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
{ {
shape s{result.get_shape().type(), result.get_shape().lens()}; shape s{result.get_shape().type(), result.get_shape().lens()};
// auto in_s = arg.get_shape();
// auto perm = find_permutation(in_s);
// if (in_s.type() == shape::half_type and perm == std::vector<int64_t>({0, 2, 1, 3}))
// {
// auto lens = s.lens();
// auto last_dim = s.lens().back();
// dim3 grid(lens[0], lens[1], lens[2]);
// dim3 block(last_dim);
// auto in_stride = in_s.strides();
// auto out_stride = s.strides();
// cont_kernel<<<grid, block, 0, stream>>>(arg.data(), result.data(), out_stride[0],
// out_stride[1], out_stride[2], in_stride[0], in_stride[1], in_stride[2]);
// }
// else
// {
visit_all(result, arg)([&](auto output_v, auto input_v) { visit_all(result, arg)([&](auto output_v, auto input_v) {
hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) { hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
mi_gs_launch(stream, mi_gs_launch(stream,
standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; }); standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
}); });
}); });
// }
} }
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg) void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
......
#include <migraphx/gpu/device/gelu.hpp> #include <migraphx/gpu/device/gelu.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cmath> #include <cmath>
namespace migraphx { namespace migraphx {
...@@ -32,15 +34,69 @@ void gelu_new(hipStream_t stream, const argument& result, const argument& arg) ...@@ -32,15 +34,69 @@ void gelu_new(hipStream_t stream, const argument& result, const argument& arg)
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); });
} }
// BERT-pattern check for the fused add+gelu path: the first shape is rank 2
// and the second input is broadcast over the leading dimension (zero stride
// on axis 0).
static bool is_bert(const std::vector<shape>& ss)
{
    if(ss.front().lens().size() != 2)
        return false;
    return ss.at(1).strides()[0] == 0;
}
// Fused broadcast-add + GeLU on packed half2 data:
//   r[i] = gelu(a[i] + b[i % n_dim])
// using the erf formulation gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// n counts half2 elements; b is broadcast along the leading dimension (its
// index wraps every n_dim elements).
__global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
{
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
int idb = tid % n_dim;
// x = (a + b) * (1/sqrt(2))
auto sum = __hadd2(ha[tid], hb[idb]);
__half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
auto x = __hmul2(sum, sqrt2);
// No half2 erf intrinsic: widen to float2, apply erff per lane, narrow back.
auto f2 = __half22float2(x);
f2.x = ::erff(f2.x);
f2.y = ::erff(f2.y);
auto h2 = __floats2half2_rn(f2.x, f2.y);
auto one = __float2half2_rn(1.0f);
// h2 = 1 + erf(x)
h2 = __hadd2(h2, one);
__half2 point5 = __float2half2_rn(0.5f);
// result = sum * 0.5 * (1 + erf(sum / sqrt(2)))
hr[tid] = __hmul2(sum, __hmul2(point5, h2));
}
}
void add_gelu(hipStream_t stream, void add_gelu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2) const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { auto sr = result.get_shape();
auto sum = to_hip_type(x + y); auto type = sr.type();
return gelu_fn(sum); std::vector<shape> ss;
}); ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
if(type == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_gelu_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
auto sum = to_hip_type(x + y);
return gelu_fn(sum);
});
}
} }
void add_gelu_new(hipStream_t stream, void add_gelu_new(hipStream_t stream,
......
...@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal) ...@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{ {
assert(s.standard); assert(s.standard);
assert(s.elements() > 0); assert(s.elements() > 0);
index_int n = s.elements(); index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal; index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * nlocal; // change the max group num to 1 Million
index_int nglobal = std::min<index_int>((1 << 20), groups) * nlocal;
assert(groups > 0); assert(groups > 0);
assert(nglobal > 0); assert(nglobal > 0);
......
...@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY); ...@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
if(enabled(MIGRAPHX_TRACE_NARY{})) \ if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl; std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
static index_int group_num_global = (1 << 8);
template <class... Ts> template <class... Ts>
constexpr auto pack(Ts... xs) constexpr auto pack(Ts... xs)
{ {
...@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl( ...@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)( hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) { [&](auto output, auto binput, auto... inputs) {
...@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) { hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
...@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl( ...@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)( hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
...@@ -234,7 +236,7 @@ void nary_double_broadcast_impl( ...@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)( hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
......
...@@ -44,12 +44,13 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, ...@@ -44,12 +44,13 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
template <index_int N, class Op, class T, class Input, class Output> template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output) __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{ {
block_scan<N>(idx, block_scan<N>(
op, idx,
init, op,
[&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); }, init,
input, [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
output); input,
output);
} }
} // namespace device } // namespace device
......
...@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f) ...@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
{ {
switch(n) switch(n)
{ {
case 1: case 1: {
{
f(std::integral_constant<index_int, 1>{}); f(std::integral_constant<index_int, 1>{});
break; break;
} }
case 2: case 2: {
{
f(std::integral_constant<index_int, 2>{}); f(std::integral_constant<index_int, 2>{});
break; break;
} }
case 3: case 3: {
{
f(std::integral_constant<index_int, 3>{}); f(std::integral_constant<index_int, 3>{});
break; break;
} }
case 4: case 4: {
{
f(std::integral_constant<index_int, 4>{}); f(std::integral_constant<index_int, 4>{});
break; break;
} }
case 5: case 5: {
{
f(std::integral_constant<index_int, 5>{}); f(std::integral_constant<index_int, 5>{});
break; break;
} }
......
...@@ -184,11 +184,11 @@ auto layernorm_fusion(hipStream_t stream, ...@@ -184,11 +184,11 @@ auto layernorm_fusion(hipStream_t stream,
const Arguments&... args) const Arguments&... args)
{ {
return [=](auto input, auto output) { return [=](auto input, auto output) {
auto relements = arg1.get_shape().lens().back(); auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements; auto nelements = result.get_shape().elements() / relements;
auto output_shape = result.get_shape(); // auto output_shape = result.get_shape();
auto reduce_output_lens(output_shape.lens()); // auto reduce_output_lens(output_shape.lens());
reduce_output_lens.back() = 1; // reduce_output_lens.back() = 1;
if((relements % 4) == 0) if((relements % 4) == 0)
layernorm_vec_impl<4>( layernorm_vec_impl<4>(
......
#include <migraphx/gpu/device/mul.hpp> #include <migraphx/gpu/device/mul.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) static bool is_bert(const std::vector<shape>& ss)
{
auto n_dim = ss.front().lens().size();
if(n_dim == 2)
{
auto stride = ss.at(1).strides();
return (stride[0] == 0);
}
return false;
}
__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; }); __half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
int idb = tid % n_dim;
hr[tid] = __hmul2(ha[tid], hb[idb]);
}
} }
void mul(hipStream_t stream, void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) auto sr = result.get_shape();
__device__ { return x * y * z; }); std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
}
} }
} // namespace device } // namespace device
......
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
#include <migraphx/gpu/device/mul_add.hpp> #include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Fused multiply-add on packed half2 data where x and b are both broadcast
// along the leading dimension:
//   r[i] = a[i] * x[i % dim3] + b[i % dim3]  for i in [0, n)
// n counts half2 elements.
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
{
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if(i >= n)
        return;
    auto* pa = reinterpret_cast<__half2*>(a);
    auto* px = reinterpret_cast<__half2*>(x);
    auto* pb = reinterpret_cast<__half2*>(b);
    auto* pr = reinterpret_cast<__half2*>(r);
    const int bcast = i % dim3;
    pr[i]           = __hfma2(pa[i], px[bcast], pb[bcast]);
}
// Fused multiply-add on packed half2 data for the higher-rank case: a and x
// are indexed elementwise, while b uses
//   idb = (id / (factor * dim4)) * dim4 + id % dim4
// i.e. b repeats across groups of `factor` inner blocks of length dim4.
// NOTE(review): this looks like b broadcast over the middle dimension of an
// (outer, factor, dim4) layout — confirm against the caller's shapes.
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
int id = blockDim.x * blockIdx.x + threadIdx.x;
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r);
if(id < n)
{
int idb = id / (factor * dim4) * dim4 + id % dim4;
hr[id] = __hfma2(ha[id], hx[id], hb[idb]);
}
}
// Detects the BERT-style mul_add broadcast pattern from the input strides:
// rank 3 — the third argument is broadcast along its middle dimension;
// rank 2 — the second and third arguments share identical strides and are
// broadcast along the leading dimension. Anything else is not fused.
static bool is_bert(const std::vector<shape>& ss)
{
    switch(ss.front().lens().size())
    {
    case 3: return ss.at(2).strides()[1] == 0;
    case 2: {
        const auto& sb = ss.at(1).strides();
        const auto& sc = ss.at(2).strides();
        return sb == sc and sb[0] == 0;
    }
    default: return false;
    }
}
void mul_add(hipStream_t stream, void mul_add(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) auto sr = result.get_shape();
__device__ { return a * x + b; }); auto type = sr.type();
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
ss.push_back(arg3.get_shape());
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
if(type == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
if(n_dim == 2)
{
mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
}
else
{
int factor = lens[1];
mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
}
}
else
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
__device__ { return a * x + b; });
}
} }
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment