"src/targets/vscode:/vscode.git/clone" did not exist on "099e9ce8e6f1df3c22e4aa157661034d3f759bf2"
Commit d6b4ae77 authored by Shucai Xiao
Browse files

merge optimization to print flops branch

parents bdf91961 abe2a889
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// ONNX parser for ScatterND. Maps the optional "reduction" attribute to one
// of the scatternd_{none,add,mul} operators; any other (or absent) reduction
// value falls back to plain assignment (scatternd_none).
struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        // Default when no recognized reduction attribute is present.
        std::string op_name = "scatternd_none";
        if(contains(info.attributes, "reduction"))
        {
            const auto& reduction = info.attributes.at("reduction").s();
            if(reduction == "add" or reduction == "mul")
                op_name = "scatternd_" + reduction;
        }
        return info.add_instruction(migraphx::make_op(op_name), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -947,9 +947,9 @@ void program::perf_report(std::ostream& os,
}
// Dump the whole program to stdout (debugging convenience; endl also flushes).
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
void program::debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& ins_names) const
{
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return is_end(pp.second.end(), ins);
}))
......@@ -965,14 +965,10 @@ void program::debug_print(instruction_ref ins) const
return;
}
std::stringstream ss;
this->print(names, [&](auto x, auto ins_names) {
if(x == ins)
{
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
});
if(contains(ins_names, ins))
{
instruction::print(std::cout, ins, ins_names);
}
}
void program::debug_print(std::ostream& os,
......@@ -1078,11 +1074,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
transform_if(
m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
......
......@@ -303,86 +303,90 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name);
m.def("parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def(
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def(
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def(
"load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
......
......@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
// Repeatedly apply reduce_dim at index n until it reports no further
// reduction is possible or n reaches the end; returns n + 1.
// Fix: a merge left two copies of the loop back to back -- the second loop
// re-invoked reduce_dim (which mutates `shapes`) one extra time after the
// first loop had already terminated. Only a single loop is kept.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    while(reduce_dim(shapes, n) and n < shapes.size()) {}
    return n + 1;
}
......
......@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
// Deep-copy the impl so mutating accessors (see shape::with_type) can modify
// the clone without aliasing the shared pimpl.
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
const std::vector<shape::type_t>& shape::types()
......@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l);
}
// Return a shape identical to this one except for its element type.
// The pimpl is cloned first so the original shape is left untouched.
shape shape::with_type(type_t t) const
{
    auto cloned    = impl->copy();
    cloned->m_type = t;
    return {cloned};
}
// Number of elements of storage this shape addresses (delegates to the impl).
std::size_t shape::element_space() const { return impl->element_space(); }
// Human-readable name of this shape's element type.
std::string shape::type_string() const { return name(this->type()); }
......
......@@ -335,7 +335,6 @@ struct find_concat_op
}
auto y = p.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
......
......@@ -327,7 +327,6 @@ struct find_nested_concat
else
args.push_back(i);
}
})(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args);
}
......
......@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs)
{
return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0)
return true;
if(s.broadcasted())
......
......@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif()
include(Embed)
file(GLOB KERNEL_FILES
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES})
......@@ -133,6 +133,7 @@ add_library(migraphx_gpu
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
compile_scatternd.cpp
concat.cpp
convert.cpp
convolution.cpp
......
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// HIP source for the runtime-compiled ScatterND kernel. The REDUCTION macro
// (assign_none / assign_add / assign_mul) is injected on the compile command
// line by compile_scatternd below. NOTE(review): the trailing main() appears
// to exist only to satisfy compilation of this source as a standalone
// translation unit -- confirm against compile_hip_code_object.
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., REDUCTION);
});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
// Build a compiled HIP operation implementing ONNX ScatterND with the given
// reduction ("none", "add" or "mul"). io_shapes lists the kernel's argument
// shapes with the output shape last; the launch is sized so that one work
// item covers each element of the updates tensor (io_shapes[1]).
operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
    hip_compile_options options;
    options.inputs         = io_shapes;
    options.output         = io_shapes.back();
    options.virtual_inputs = io_shapes;
    options.kernel_name    = "scatternd_kernel";
    options.local          = 1024;
    options.global         = compute_global(io_shapes.at(1).elements(), options.local);
    // Bake the reduction functor into the kernel source via -DREDUCTION.
    options.params += " -DREDUCTION=assign_" + reduction + "{}";
    return compile_hip_code_object(scatternd_kernel, options);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Heuristic for the BERT elementwise-add pattern: rank-2 tensors where the
// second argument's leading stride is 0, i.e. it is broadcast along dim 0.
static bool is_bert(const std::vector<shape>& ss)
{
    if(ss.front().lens().size() != 2)
        return false;
    return ss.at(1).strides()[0] == 0;
}
// Half-precision add with the second operand broadcast along the leading
// dimension: r[tid] = a[tid] + b[tid % n_dim], two __half values per thread
// via __half2. NOTE(review): n is the number of __half2 pairs in the output
// and n_dim the __half2-packed length of the broadcast operand -- the caller
// passes elements()/2 and lens().back()/2 respectively.
__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
    {
        // b wraps around every n_dim pairs (broadcast along dim 0).
        int idb = tid % n_dim;
        hr[tid] = __hadd2(ha[tid], hb[idb]);
    }
}
// Elementwise add: result = arg1 + arg2.
// Half-precision inputs matching the BERT broadcast pattern use the
// hand-written __half2 kernel; everything else falls back to the generic
// nary elementwise implementation.
// Fix: removed a stray unconditional nary(...) call left over from a merge --
// it computed the result a second time on every invocation (and redundantly
// re-ran the generic path even when the fast kernel was taken).
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    auto sr = result.get_shape();
    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());
    if(sr.type() == shape::half_type and is_bert(ss))
    {
        // Two __half values per thread, packed as __half2.
        auto elem_num  = sr.elements() / 2;
        auto last_dim  = sr.lens().back() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;
        add_kernel<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
    }
    else
    {
        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
    }
}
void add(hipStream_t stream,
......
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/permutation.hpp>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Half-precision 4-D copy kernel: the first three output/input strides are
// passed explicitly; the innermost dimension is mapped to threadIdx.x and
// assumed contiguous (stride 1) in both tensors.
// NOTE(review): not called from any live code path in this file -- it is only
// referenced by commented-out experimental code in contiguous_nonstandard.
__global__ void
cont_kernel(void* in, void* out, int os1, int os2, int os3, int is1, int is2, int is3)
{
    int i1 = blockIdx.x;
    int i2 = blockIdx.y;
    int i3 = blockIdx.z;
    int i4 = threadIdx.x;
    __half* in_ptr = reinterpret_cast<__half*>(in);
    __half* out_ptr = reinterpret_cast<__half*>(out);
    int out_idx = i1 * os1 + i2 * os2 + i3 * os3 + i4;
    int in_idx = i1 * is1 + i2 * is2 + i3 * is3 + i4;
    out_ptr[out_idx] = in_ptr[in_idx];
}
// Copy a non-standard-layout (e.g. transposed/broadcast) tensor into densely
// packed order. A standard-layout shadow shape with the same lens is used so
// a linear walk over the output performs strided reads from the input.
// Fix: removed a large block of commented-out experimental fast-path code
// (half-type permutation special case) that was left in the body and obscured
// the actual logic.
void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
{
    shape s{result.get_shape().type(), result.get_shape().lens()};
    visit_all(result, arg)([&](auto output_v, auto input_v) {
        hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
            mi_gs_launch(stream,
                         standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
        });
    });
}
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
......
#include <migraphx/gpu/device/gelu.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cmath>
namespace migraphx {
......@@ -32,15 +34,69 @@ void gelu_new(hipStream_t stream, const argument& result, const argument& arg)
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); });
}
// True for the BERT add+gelu broadcast pattern: rank-2 tensors where the
// second argument is broadcast along dim 0 (leading stride of 0).
// NOTE(review): duplicated in add.cpp / mul.cpp -- consider a shared helper.
static bool is_bert(const std::vector<shape>& ss)
{
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
        auto stride = ss.at(1).strides();
        return (stride[0] == 0);
    }
    return false;
}
// Fused add + GELU on __half2 pairs with b broadcast along the leading
// dimension: sum = a[tid] + b[tid % n_dim], then
// gelu(sum) = 0.5 * sum * (1 + erf(sum / sqrt(2))).
// The erf is evaluated in float (via __half22float2) before converting back
// to half. n counts __half2 pairs in the output.
__global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
    {
        // Broadcast b along dim 0.
        int idb = tid % n_dim;
        auto sum = __hadd2(ha[tid], hb[idb]);
        // x = sum / sqrt(2), computed as sum * (1/sqrt(2)).
        __half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
        auto x = __hmul2(sum, sqrt2);
        // erf has no __half2 intrinsic; do it in float for accuracy.
        auto f2 = __half22float2(x);
        f2.x = ::erff(f2.x);
        f2.y = ::erff(f2.y);
        auto h2 = __floats2half2_rn(f2.x, f2.y);
        auto one = __float2half2_rn(1.0f);
        h2 = __hadd2(h2, one);
        __half2 point5 = __float2half2_rn(0.5f);
        // r = 0.5 * sum * (1 + erf(sum/sqrt(2)))
        hr[tid] = __hmul2(sum, __hmul2(point5, h2));
    }
}
// Fused add + GELU: result = gelu(arg1 + arg2).
// Half-precision inputs matching the BERT broadcast pattern use the fused
// __half2 kernel; all other cases fall back to the generic nary
// implementation.
// Fix: removed a stray unconditional nary(...) call left over from a merge
// that computed the result a second time on every invocation.
void add_gelu(hipStream_t stream,
              const argument& result,
              const argument& arg1,
              const argument& arg2)
{
    auto sr   = result.get_shape();
    auto type = sr.type();
    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());
    if(type == shape::half_type and is_bert(ss))
    {
        // Two __half values per thread, packed as __half2.
        auto elem_num  = sr.elements() / 2;
        auto last_dim  = sr.lens().back() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;
        add_gelu_kernel<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
    }
    else
    {
        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
            auto sum = to_hip_type(x + y);
            return gelu_fn(sum);
        });
    }
}
}
void add_gelu_new(hipStream_t stream,
......
......@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{
assert(s.standard);
assert(s.elements() > 0);
index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * nlocal;
index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal;
// change the max group num to 1 Million
index_int nglobal = std::min<index_int>((1 << 20), groups) * nlocal;
assert(groups > 0);
assert(nglobal > 0);
......
......@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
// Work-group count used by the nary_* launch helpers (total threads =
// group_num_global * nlocal). NOTE(review): a non-const namespace-scope
// `static` gives every translation unit its own mutable copy if this lives in
// a header; it is never modified here, so consider `constexpr` -- confirm.
static index_int group_num_global = (1 << 8);
template <class... Ts>
constexpr auto pack(Ts... xs)
{
......@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4;
const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal;
const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) {
......@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal;
const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
......@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4;
const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal;
const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
......@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal;
const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) {
......
......@@ -44,12 +44,13 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
// Convenience overload of block_scan taking an element count n: forwards to
// the ForStride overload, striding the block's threads over [0, n).
// Fix: a merge left two identical forwarding calls back to back, which would
// have executed the scan (and its writes to `output`) twice. Only one call
// is kept.
template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{
    block_scan<N>(
        idx,
        op,
        init,
        [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
        input,
        output);
}
} // namespace device
......
......@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
{
switch(n)
{
case 1:
{
case 1: {
f(std::integral_constant<index_int, 1>{});
break;
}
case 2:
{
case 2: {
f(std::integral_constant<index_int, 2>{});
break;
}
case 3:
{
case 3: {
f(std::integral_constant<index_int, 3>{});
break;
}
case 4:
{
case 4: {
f(std::integral_constant<index_int, 4>{});
break;
}
case 5:
{
case 5: {
f(std::integral_constant<index_int, 5>{});
break;
}
......
......@@ -184,11 +184,11 @@ auto layernorm_fusion(hipStream_t stream,
const Arguments&... args)
{
return [=](auto input, auto output) {
auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements;
auto output_shape = result.get_shape();
auto reduce_output_lens(output_shape.lens());
reduce_output_lens.back() = 1;
auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements;
// auto output_shape = result.get_shape();
// auto reduce_output_lens(output_shape.lens());
// reduce_output_lens.back() = 1;
if((relements % 4) == 0)
layernorm_vec_impl<4>(
......
#include <migraphx/gpu/device/mul.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
// True for the BERT elementwise-mul broadcast pattern: rank-2 tensors where
// the second argument is broadcast along dim 0 (leading stride of 0).
// NOTE(review): duplicated in add.cpp / gelu.cpp -- consider a shared helper.
static bool is_bert(const std::vector<shape>& ss)
{
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
        auto stride = ss.at(1).strides();
        return (stride[0] == 0);
    }
    return false;
}
// Half-precision multiply with the second operand broadcast along the leading
// dimension: r[tid] = a[tid] * b[tid % n_dim], two __half values per thread
// via __half2. n counts __half2 pairs in the output.
// Fix: removed a stray host-side `nary(stream, result, arg1, arg2)(...)` line
// that a merge had injected into this device kernel's body -- it referenced
// identifiers that do not exist here and would not compile as device code.
__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
    {
        // b wraps around every n_dim pairs (broadcast along dim 0).
        int idb = tid % n_dim;
        hr[tid] = __hmul2(ha[tid], hb[idb]);
    }
}
// Elementwise multiply: result = arg1 * arg2.
// Half-precision inputs matching the BERT broadcast pattern use the
// hand-written __half2 kernel; everything else falls back to the generic
// nary implementation.
// Fix: removed merge residue -- a leftover three-argument overload signature
// fused onto this definition and a stray nary(x*y*z) call referencing a
// non-existent arg3, both of which made the function ill-formed.
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    auto sr = result.get_shape();
    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());
    if(sr.type() == shape::half_type and is_bert(ss))
    {
        // Two __half values per thread, packed as __half2.
        auto elem_num  = sr.elements() / 2;
        auto last_dim  = sr.lens().back() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;
        mul_kernel<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
    }
    else
    {
        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
    }
}
} // namespace device
......
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Fused multiply-add r = a * x + b for half precision where both x and b are
// broadcast over the leading dimension(s): element id reads x[id % dim3] and
// b[id % dim3], operating on __half2 pairs via __hfma2.
// NOTE(review): n counts __half2 pairs in r and dim3 the __half2-packed
// innermost length -- the caller passes elements()/2 and lens().back()/2.
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
{
    int id = blockDim.x * blockIdx.x + threadIdx.x;
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hx = reinterpret_cast<__half2*>(x);
    __half2* hr = reinterpret_cast<__half2*>(r);
    if(id < n)
    {
        // Both x and b wrap around every dim3 pairs.
        auto id1 = id % dim3;
        hr[id] = __hfma2(ha[id], hx[id1], hb[id1]);
    }
}
// Fused multiply-add r = a * x + b for half precision where x is read
// elementwise and b is broadcast over a middle dimension of extent `factor`:
// the bias index collapses that dimension via
// idb = id / (factor * dim4) * dim4 + id % dim4.
// n counts __half2 pairs in r; dim4 is the __half2-packed innermost length.
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
    int id = blockDim.x * blockIdx.x + threadIdx.x;
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hx = reinterpret_cast<__half2*>(x);
    __half2* hr = reinterpret_cast<__half2*>(r);
    if(id < n)
    {
        // Skip over the broadcast middle dimension when indexing b.
        int idb = id / (factor * dim4) * dim4 + id % dim4;
        hr[id] = __hfma2(ha[id], hx[id], hb[idb]);
    }
}
// Heuristic for the BERT mul_add pattern. Rank-3 inputs: the third argument
// must be broadcast along dim 1 (stride 0). Rank-2 inputs: the second and
// third arguments must share strides and be broadcast along dim 0.
static bool is_bert(const std::vector<shape>& ss)
{
    switch(ss.front().lens().size())
    {
    case 3: return ss.at(2).strides()[1] == 0;
    case 2: {
        const auto& s1 = ss.at(1).strides();
        return s1 == ss.at(2).strides() and s1[0] == 0;
    }
    default: return false;
    }
}
// Fused multiply-add: result = arg2 * arg1 + arg3 (the nary lambda maps
// x=arg1, a=arg2, b=arg3 and returns a * x + b).
// Half-precision inputs matching the BERT broadcast pattern dispatch to one
// of two __half2 kernels depending on rank; everything else falls back to the
// generic nary implementation.
// Fix: removed a stray unconditional nary(...) call left over from a merge
// that computed the result a second time on every invocation.
void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    auto sr   = result.get_shape();
    auto type = sr.type();
    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());
    ss.push_back(arg3.get_shape());
    auto lens    = sr.lens();
    int last_dim = lens.back() / 2; // __half2-packed innermost length
    auto n_dim   = lens.size();
    if(type == shape::half_type and is_bert(ss))
    {
        // Two __half values per thread, packed as __half2.
        auto elem_num  = sr.elements() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;
        if(n_dim == 2)
        {
            mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
        }
        else
        {
            // Rank-3 case: b is broadcast over the middle dimension.
            int factor = lens[1];
            mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
        }
    }
    else
    {
        nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
                                                   __device__ { return a * x + b; });
    }
}
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment