Unverified commit faefeef9 authored by Charlie Lin, committed by GitHub

Merge branch 'develop' into dyn_shape_update

parents 97a40ac3 bf0a4713
@@ -17,9 +17,9 @@ struct schedule_model
 {
     std::size_t streams = 0;
     std::size_t concurrency() const;
-    void sched(module& p, instruction_ref ins, std::size_t n) const;
-    void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
-    void record(module& p, instruction_ref ins, std::size_t wait_id) const;
+    void sched(module& m, instruction_ref ins, std::size_t n) const;
+    void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
+    void record(module& m, instruction_ref ins, std::size_t wait_id) const;
     std::size_t weight(const operation& op) const;
 };
...
@@ -15,7 +15,7 @@ namespace gpu {
 struct sync_device
 {
     std::string name() const { return "sync_device"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -14,7 +14,7 @@ struct write_literals
     context* ctx = nullptr;
     std::string name() const { return "gpu::write_literals"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
...
@@ -19,7 +19,7 @@ namespace gpu {
 // NOLINTNEXTLINE
 static const char* const gathernd_kernel = R"__migraphx__(
 #include <migraphx/kernels/gathernd.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
+#include <migraphx/kernels/ops.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
...
@@ -6,6 +6,7 @@
 #include <migraphx/cpp_generator.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
@@ -26,9 +27,10 @@ namespace migraphx {
 ${preamble}

 extern "C" {
-__global__ void kernel(${params})
+__global__ void ${kernel}(${params})
 {
-    pointwise(${lambda}, ${args});
+    auto idx = make_index();
+    pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
 }
 }
@@ -37,44 +39,123 @@ __global__ void kernel(${params})
 )__migraphx__";

+static std::vector<std::string> get_op_names(const module& m)
+{
+    std::vector<std::string> result;
+    for(auto& ins : m)
+    {
+        if(starts_with(ins.name(), "@"))
+            continue;
+        result.push_back(ins.name());
+    }
+    return result;
+}
+
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
     std::vector<std::string> names() const { return {"pointwise"}; }

-    static std::size_t oversubscribe(const std::vector<shape>& inputs)
+    static std::size_t oversubscribe_if(bool b)
     {
-        if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
-            return 1;
-        else
+        if(b)
             return 256;
+        else
+            return 1;
     }
-    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
+    static std::size_t find_fast_axis(const std::vector<shape>& inputs)
     {
-        std::size_t n = inputs.front().elements();
-        if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
-               return s.packed() or s.broadcasted();
-           }))
-        {
-            if((n % 4) == 0)
-                return n / 4;
-            else if((n % 2) == 0)
-                return n / 2;
-        }
-        return n;
+        auto permutation = find_permutation(inputs);
+        auto it          = std::max_element(permutation.begin(), permutation.end());
+        return it - permutation.begin();
+    }
+    static std::vector<bool> preload(std::size_t axis, const std::vector<shape>& inputs)
+    {
+        const std::size_t max_lds_bytes = 4096;
+        std::vector<bool> result;
+        std::transform(inputs.begin(),
+                       inputs.end(),
+                       std::back_inserter(result),
+                       [&](const shape& input) { return input.strides()[axis] == 0; });
+        auto bytes = std::inner_product(inputs.begin(),
+                                        inputs.end(),
+                                        result.begin(),
+                                        std::size_t{0},
+                                        std::plus<>{},
+                                        [](const shape& s, bool b) -> std::size_t {
+                                            if(b)
+                                                return s.bytes();
+                                            return 0;
+                                        });
+        if(bytes < max_lds_bytes)
+            return result;
+        // TODO: Try to partially preload items
+        std::fill(result.begin(), result.end(), false);
+        return result;
+    }
+    static std::string preload_str(const std::vector<bool>& bs)
+    {
+        std::vector<std::string> bool_strs;
+        std::transform(bs.begin(), std::prev(bs.end()), std::back_inserter(bool_strs), [](bool b) {
+            if(b)
+                return "true";
+            return "false";
+        });
+        return "false, " + join_strings(bool_strs, ", ");
+    }
+    static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
+    {
+        // If all inputs are half then only use half2
+        if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
+               return s.type() == shape::half_type;
+           }))
+            return {2};
+        return {4, 2};
+    }
+    static auto vectorize_elements(std::size_t axis, const std::vector<shape>& inputs)
+    {
+        auto sizes = vector_sizes(inputs);
+        std::vector<std::size_t> max_vec_size;
+        std::transform(inputs.begin(),
+                       inputs.end(),
+                       std::back_inserter(max_vec_size),
+                       [&](const auto& input) -> std::size_t {
+                           auto stride = input.strides()[axis];
+                           auto len    = input.lens()[axis];
+                           if(stride != 0 and stride != 1)
+                               return 1;
+                           auto it = std::find_if(
+                               sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
+                           if(it != sizes.end())
+                               return *it;
+                           return 1;
+                       });
+        return *std::min_element(max_vec_size.begin(), max_vec_size.end());
     }
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(
-            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
         options.inputs         = inputs;
         options.output         = inputs.back();
         options.virtual_inputs = reduce_dims(inputs);
         options.params         = "-Wno-float-equal";
+        auto axis     = find_fast_axis(options.virtual_inputs);
+        auto vec_size = vectorize_elements(axis, options.virtual_inputs);
+        auto preloads = preload(axis, options.virtual_inputs);
+        auto is_preloading =
+            std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
+        options.kernel_name = v.get("kernel", "kernel");
+        options.set_launch_params(v,
+                                  compute_global_for(ctx,
+                                                     options.output.elements() / vec_size,
+                                                     oversubscribe_if(not is_preloading)));
         auto src = interpolate_string(pointwise_kernel,
-                                      {{"params", enum_params(inputs.size(), "void * private_p")},
+                                      {{"kernel", options.kernel_name},
+                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"lambda", v.at("lambda").to<std::string>()},
+                                       {"vec_size", std::to_string(vec_size)},
+                                       {"axis", std::to_string(axis)},
+                                       {"preloads", preload_str(preloads)},
                                        {"preamble", v.get("preamble", std::string{})}});
         return compile_hip_code_object(src, options);
     }
@@ -100,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
         auto name = g.create_function(
             g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
         std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
+        auto op_names      = get_op_names(*pm);
+        op_names.push_back("kernel");
+        auto op_name_string = join_strings(op_names, "_");
         return replace(
-            compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}}));
+            compile_op(ctx,
+                       to_shapes(ins->inputs()),
+                       {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
     }
 };
 } // namespace gpu
...
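To see what these selection heuristics do in practice, here is a standalone host-side sketch, not the MIGraphX API: `toy_shape`, the hard-coded fast axis, and the 4-byte element assumption are all illustrative, and the per-input preload check below simplifies the real pass, which also clears every flag when the combined buffers exceed the 4 KB budget.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for migraphx::shape: lengths and strides only,
// with 4-byte elements assumed for the bytes() calculation.
struct toy_shape
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;

    std::size_t element_space() const
    {
        std::size_t n = 1;
        for(std::size_t i = 0; i < lens.size(); i++)
            n += (lens[i] - 1) * strides[i];
        return n;
    }
    std::size_t bytes() const { return 4 * element_space(); }
};

// Mirrors vectorize_elements: an input only vectorizes on the fast axis if its
// stride there is 0 or 1; take the largest of {4, 2} that divides every length.
std::size_t vec_size_for(std::size_t axis, const std::vector<toy_shape>& inputs)
{
    std::size_t result = 4;
    for(const auto& s : inputs)
    {
        std::size_t best = 1;
        if(s.strides[axis] <= 1)
            for(std::size_t n : {std::size_t{4}, std::size_t{2}})
                if(s.lens[axis] % n == 0)
                {
                    best = n;
                    break;
                }
        result = std::min(result, best);
    }
    return result;
}

int main()
{
    // x varies along the fast axis; scale is broadcast along it (stride 0).
    toy_shape x{{64, 768}, {768, 1}};
    toy_shape scale{{64, 768}, {1, 0}};
    const std::size_t axis = 1; // fast axis, as find_permutation would report

    std::cout << "vec_size = " << vec_size_for(axis, {x, scale}) << "\n"; // 4

    // Preload rule: stride 0 on the fast axis, and the buffer fits in LDS.
    for(const auto& s : std::vector<toy_shape>{x, scale})
        std::cout << "preload = " << (s.strides[axis] == 0 and s.bytes() < 4096) << "\n";
    // x: 0 (it varies on the fast axis), scale: 1 (256 bytes fit in LDS)
}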
@@ -19,7 +19,6 @@ namespace gpu {
 // NOLINTNEXTLINE
 static const char* const roialign_kernel = R"__migraphx__(
 #include <migraphx/kernels/roialign.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
...
@@ -19,7 +19,6 @@ namespace gpu {
 // NOLINTNEXTLINE
 static const char* const scatternd_kernel = R"__migraphx__(
 #include <migraphx/kernels/scatternd.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
...
@@ -146,8 +146,8 @@ struct array
     constexpr array carry(array result) const
     {
-        uint32_t overflow = 0;
-        for(std::ptrdiff_t i = result.size() - 1; i > 0; i--)
+        index_int overflow = 0;
+        for(diff_int i = result.size() - 1; i > 0; i--)
         {
             auto z = result[i] + overflow;
             // Reset overflow
...
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP

#include <migraphx/kernels/types.hpp>

namespace migraphx {

struct sum
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const
    {
        return x + y;
    }
};

struct product
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const
    {
        return x * y;
    }
};

struct id
{
    template <class T>
    constexpr auto operator()(T x) const
    {
        return x;
    }
};

struct mean
{
    size_t item_num = 1;
    template <class T>
    constexpr auto operator()(T x) const
    {
        return x / static_cast<T>(item_num);
    }
};

struct max_f
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const
    {
        return (x > y) ? x : y;
    }
};
inline constexpr auto max = max_f{};

struct min_f
{
    template <class T, class U>
    constexpr auto operator()(T x, U y) const
    {
        return (x < y) ? x : y;
    }
};
inline constexpr auto min = min_f{};

struct lowest
{
    template <class T>
    constexpr operator T() const
    {
        return std::numeric_limits<T>::lowest();
    }
};

struct highest
{
    template <class T>
    constexpr operator T() const
    {
        return std::numeric_limits<T>::max();
    }
};

} // namespace migraphx
#endif // MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
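The `lowest`/`highest` helpers rely on a templated conversion operator, so the consuming type picks the limit at the point of use. A minimal host-only demonstration of the same pattern (standard C++, independent of the kernel headers):

#include <iostream>
#include <limits>

// Same conversion-operator trick as the `lowest` op above: the target type
// selects the right numeric_limits specialization.
struct lowest
{
    template <class T>
    constexpr operator T() const
    {
        return std::numeric_limits<T>::lowest();
    }
};

int main()
{
    float f  = lowest{}; // e.g. -3.40282e+38
    int i    = lowest{}; // e.g. -2147483648
    double d = lowest{}; // e.g. -1.79769e+308
    std::cout << f << ' ' << i << ' ' << d << '\n';
}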
@@ -3,6 +3,14 @@
 #include <migraphx/kernels/array.hpp>

+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT(...) \
+    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+
 namespace migraphx {

 struct swallow
@@ -129,7 +137,7 @@ constexpr auto by(F f)
 template <class F, class... Ts>
 constexpr void each_args(F f, Ts&&... xs)
 {
-    swallow{(f(std::forward<Ts>(xs)), 0)...};
+    swallow{(f(static_cast<Ts&&>(xs)), 0)...};
 }

 template <class F>
@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
     return [=](auto f) { return f(xs...); };
 }

+template <class G, class F>
+constexpr auto join(G g, F f)
+{
+    return f([=](auto... xs) { return g(xs...); });
+}
+
+template <class G, class F, class... Fs>
+constexpr auto join(G g, F f, Fs... fs)
+{
+    return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
+}
+
 template <class Compare, class P1, class P2>
 constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
 {
@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
     return arg_c<ic>();
 }

-inline constexpr auto rotate_last()
+template <class F>
+constexpr auto make_transform(F f)
 {
-    return [](auto... xs) {
-        return [=](auto&& f) {
-            return sequence_c<sizeof...(xs)>([&](auto... is) {
-                constexpr auto size = sizeof...(is);
-                return f(arg_c<(is + size - 1) % size>()(xs...)...);
-            });
-        };
-    };
+    return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
 }

+// An arg transformation takes the arguments and then a function to take the new arguments:
+// transform(xs...)([](auto... ys) { ... })
+// The transform_args function takes a list of transformations and continually applies them
 template <class F>
 constexpr auto transform_args(F f)
 {
-    return [=](auto... xs) {
-        return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
-    };
+    return f;
 }

 template <class F, class... Fs>
 constexpr auto transform_args(F f, Fs... fs)
 {
-    return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
+    return make_transform([=](auto g, auto... xs) {
+        return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
+    });
 }

-// NOLINTNEXTLINE
-#define MIGRAPHX_RETURNS(...) \
-    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+// identity transform
+inline constexpr auto transform_args()
+{
+    return make_transform([](auto f, auto... xs) { return f(xs...); });
+}

-// NOLINTNEXTLINE
-#define MIGRAPHX_LIFT(...) \
-    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+// Rotate the last argument to the first argument
+inline constexpr auto rotate_last()
+{
+    return make_transform([](auto f, auto... xs) {
+        return sequence_c<sizeof...(xs)>([&](auto... is) {
+            constexpr auto size = sizeof...(is);
+            return f(arg_c<(is + size - 1) % size>()(xs...)...);
+        });
+    });
+}

 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
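Here is a self-contained host-side sketch of how these continuation-passing transforms compose. The combinators mirror the hunk above (with `constexpr` dropped so it compiles as plain C++14); `add_one` and `dbl` are made-up toy transforms standing in for `make_tensors`, `rotate_last`, and friends:

#include <iostream>

// A transform takes a continuation g plus the arguments, and hands g the
// rewritten arguments; make_transform curries that into t(xs...)(g) form.
template <class F>
auto make_transform(F f)
{
    return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
}

template <class F>
auto transform_args(F f)
{
    return f;
}

template <class F, class... Fs>
auto transform_args(F f, Fs... fs)
{
    return make_transform([=](auto g, auto... xs) {
        return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
    });
}

int main()
{
    // Two toy transforms: add one to every argument, then double it.
    auto add_one = make_transform([](auto f, auto... xs) { return f((xs + 1)...); });
    auto dbl     = make_transform([](auto f, auto... xs) { return f((xs * 2)...); });

    auto t = transform_args(add_one, dbl);
    // Arguments flow left to right: (3+1)*2 = 8 and (5+1)*2 = 12
    int r = t(3, 5)([](int a, int b) { return a + b; });
    std::cout << r << '\n'; // 20
}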
@@ -13,7 +13,7 @@ struct basic_iota_iterator
     F f;

     using difference_type = diff_int;
-    using reference       = decltype(f(std::declval<Iterator>()));
+    using reference       = decltype(f(declval<Iterator>()));
     using value_type      = remove_reference_t<reference>;
     using pointer         = add_pointer_t<value_type>;
...
@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {
-    preload<typename T::type>(idx, xs...)([&](auto... ps) {
-        idx.global_stride(out.get_shape().elements(),
-                          [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
-    });
+    idx.global_stride(out.get_shape().elements(),
+                      [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
 }

-template <class F, class... Ts>
-__device__ void pointwise(F f, Ts*... ps)
+template <class... Transforms>
+__device__ auto pointwise(index idx, Transforms... transforms)
 {
-    auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
-    t(ps...)([&](auto... xs) {
-        auto idx = make_index();
-        pointwise_tensor(idx, f, xs...);
-    });
+    return [=](auto f, auto*... ps) {
+        auto t = transform_args(make_tensors(), rotate_last(), transforms...);
+        t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
+    };
 }

 } // namespace migraphx
...
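With the compiler changes above, the interpolated kernel for a hypothetical two-input fused add would expand to roughly the following. The kernel name, preload flags, vec size, axis, and the `add` function name are all made-up substitution values for illustration, not generated output:

// Hypothetical expansion of pointwise_kernel with ${kernel}=add_kernel,
// ${preloads}=false, false, true, ${vec_size}=4, ${axis}=1; the parameter and
// argument lists come from enum_params over three buffers (two inputs, one output,
// rotated to the front by rotate_last inside pointwise).
extern "C" {
__global__ void add_kernel(void* private_p0, void* private_p1, void* private_p2)
{
    auto idx = make_index();
    pointwise(idx, auto_preload<false, false, true>(idx), vectorize<4, 1>())(
        MIGRAPHX_LIFT(add), private_p0, private_p1, private_p2);
}
}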
@@ -3,6 +3,8 @@

 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/tensor_view.hpp>
+#include <migraphx/kernels/vec.hpp>

 namespace migraphx {
@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
 {
     if constexpr(decltype(tensor_vec_size(x)){} == 0)
     {
-        auto v = vectorize(x);
+        auto v = auto_vectorize(x);
         auto b = as_vec(tensor_vec_size(v), buffer + offset);
         idx.local_stride(v.get_shape().element_space(),
                          [&](auto i) { b[i] = v.data()[i]; });
@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
     };
 }

+inline __device__ auto auto_preload(index idx)
+{
+    return make_transform([=](auto f, auto out, auto... xs) {
+        preload<typename decltype(out)::type>(idx, xs...)([&](auto... ys) { f(out, ys...); });
+    });
+}
+
+template <bool B, class T>
+__device__ auto preload_copy(index idx, T x)
+{
+    return [=](auto f) {
+        if constexpr(B)
+        {
+            using type          = typename T::type;
+            constexpr auto size = get_shape_c<T>{}.element_space();
+            __shared__ type buffer[size];
+            // TODO: Always vectorize when size > 4, and then use a second loop for remainder
+            constexpr auto n = find_vectorize_size([&](auto i) { return (size % i) == 0; });
+            auto input       = as_vec<n>(remove_bool(x.data()));
+            auto b           = as_vec<n>(remove_bool(buffer));
+            idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; });
+            return f(x.with(buffer));
+        }
+        else
+        {
+            return f(x);
+        }
+    };
+}
+
+template <bool... Bs>
+__device__ auto auto_preload(index idx)
+{
+    return make_transform([=](auto f, auto... xs) {
+        auto invoke = [=](auto... ys) {
+            __syncthreads();
+            f(ys...);
+        };
+        join(invoke, preload_copy<Bs>(idx, xs)...);
+    });
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
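The variadic `auto_preload` leans on `join` from functional.hpp to nest one `preload_copy` continuation per argument, so every `__shared__` buffer stays in scope when `f` finally runs. A host-side sketch of that chaining, where the toy `step` stands in for `preload_copy` and negation stands in for the copy to LDS:

#include <iostream>

// Host sketch of the join combinator used by auto_preload above: each f is a
// CPS step that may rewrite its value before passing it on; join collects the
// rewritten values and hands them to g all at once.
template <class G, class F>
auto join(G g, F f)
{
    return f([=](auto... xs) { return g(xs...); });
}

template <class G, class F, class... Fs>
auto join(G g, F f, Fs... fs)
{
    return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
}

// A step like preload_copy<B>: either passes x through or replaces it.
auto step(int x, bool replace)
{
    return [=](auto f) { return f(replace ? -x : x); };
}

int main()
{
    // Mirrors join(invoke, preload_copy<Bs>(idx, xs)...): values that opted in
    // are rewritten (here: negated), and g sees the final argument list.
    join([](int a, int b, int c) { std::cout << a << ' ' << b << ' ' << c << '\n'; },
         step(1, false),
         step(2, true),
         step(3, false)); // prints: 1 -2 3
}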
@@ -3,14 +3,15 @@

 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/dfor.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
+#include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/math.hpp>
 #include <migraphx/kernels/array.hpp>

 namespace migraphx {

 struct max_pool
 {
-    MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest(); }
+    MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest{}; }

     template <class T>
     MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
             return 0;
         }
-        xy[ii]   = max(xy[ii], 0.0f);
+        xy[ii]   = migraphx::max(xy[ii], 0.0f);
         low[ii]  = xy[ii];
         high[ii] = low[ii] + 1;
         if(low[ii] >= dims[ii] - 1)
@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
     for(index_int ii = 0; ii < roi_size.size(); ++ii)
     {
         roi_size[ii] = roi_ends[ii] - roi_starts[ii];
-        roi_size[ii] = max(roi_size[ii], 1.0f);
+        roi_size[ii] = migraphx::max(roi_size[ii], 1.0f);

         bin_size[ii]      = roi_size[ii] / out_dims[ii];
-        bin_grid_size[ii] =
-            (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
+        bin_grid_size[ii] = (s.sampling_ratio > 0)
+                                ? s.sampling_ratio
+                                : migraphx::ceil(roi_size[ii] / out_dims[ii]);
     }
     const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
...
@@ -11,7 +11,7 @@ template <class T>
 struct tensor_view_iterator_read
 {
     T* view;
-    constexpr auto& operator()(std::size_t n) const
+    constexpr auto& operator()(index_int n) const
     {
         MIGRAPHX_ASSERT(view != nullptr);
         return (*view)[n];
...
@@ -35,6 +35,21 @@ struct enable_if<true, T>
 template <bool B, class T = void>
 using enable_if_t = typename enable_if<B, T>::type;

+template <bool B, class T, class F>
+struct conditional
+{
+    using type = T;
+};
+
+template <class T, class F>
+struct conditional<false, T, F>
+{
+    using type = F;
+};
+
+template <bool B, class T, class F>
+using conditional_t = typename conditional<B, T, F>::type;
+
 // NOLINTNEXTLINE
 #define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
     template <class T>                     \
...
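A freestanding `conditional` lets kernel code select types without pulling in `<type_traits>`. A quick host-side check of the same definitions; the `safe_elem` alias is a made-up illustration of how `safe_vec` in vec.hpp uses it:

#include <cstdint>
#include <type_traits> // host-side, only for the is_same sanity checks

template <bool B, class T, class F>
struct conditional
{
    using type = T;
};

template <class T, class F>
struct conditional<false, T, F>
{
    using type = F;
};

template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;

// Same selection safe_vec performs: bool elements become uint8_t lanes.
template <class T>
using safe_elem = conditional_t<std::is_same<T, bool>::value, std::uint8_t, T>;

static_assert(std::is_same<conditional_t<true, int, float>, int>::value, "");
static_assert(std::is_same<conditional_t<false, int, float>, float>::value, "");
static_assert(std::is_same<safe_elem<bool>, std::uint8_t>::value, "");
static_assert(std::is_same<safe_elem<float>, float>::value, "");

int main() {}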
@@ -60,17 +60,26 @@ constexpr auto common_vec_size()
     })(vec_size<Ts>()...);
 }

+// Bools can not be used as a vector type so convert it to uint8
+template <class T>
+__device__ __host__ T* remove_bool(T* x)
+{
+    return x;
+}
+
+inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
+
 template <index_int N, class T>
 __device__ __host__ auto as_vec(T* x)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
         return x;
     else
         return reinterpret_cast<vec<T, N>*>(x);
 }

 template <class T, index_int N>
-using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
+using safe_vec = vec<conditional_t<is_same<T, bool>{}, uint8_t, T>, N>;

 template <class... Ts>
 constexpr auto vec_transform(Ts... xs)
...
@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
     });
 }

-// Bools can not be used as a vector type so convert it to uint8
-template <class T>
-__device__ __host__ T* remove_bool(T* x)
-{
-    return x;
-}
-
-inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
-
 template <index_int N, class T, class Axis>
 __device__ __host__ auto as_vec(T x, Axis axis)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
         return x;
     else
         return make_tensor_view(as_vec<N>(remove_bool(x.data())),
@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
 template <index_int N, class T, class Axis>
 constexpr auto tensor_step(T x, Axis axis)
 {
-    if constexpr(N == 0)
+    if constexpr(N < 2)
     {
         return x;
     }
@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
     else if constexpr(decltype(pred(_c<2>)){})
         return _c<2>;
     else
-        return _c<0>;
+        return _c<1>;
 }

 template <class T>
-__host__ __device__ auto vectorize(T x)
+__host__ __device__ auto auto_vectorize(T x)
 {
     if constexpr(tensor_vec_size<T>() == 0)
     {
@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
     {
         MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
         MIGRAPHX_ASSERT(s.lens[axis] > 0);
-        MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
+        MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0);
         if constexpr(s.strides[axis] == 0)
             return tensor_step<n>(x, axis);
         else
@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
 inline __device__ __host__ auto auto_vectorize()
 {
-    return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
+    return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); });
+}
+
+template <index_int N, index_int Axis, class T>
+__device__ __host__ auto vectorize_tensor(T x)
+{
+    constexpr auto shape = get_shape_c<T>{};
+    if constexpr(shape.strides[Axis] == 0)
+        return tensor_step<N>(x, _c<Axis>);
+    else
+        return as_vec<N>(x, _c<Axis>);
+}
+
+template <index_int N, index_int Axis>
+__device__ __host__ auto vectorize()
+{
+    return make_transform([](auto f, auto... xs) {
+        if constexpr(N < 2)
+        {
+            f(xs...);
+        }
+        else
+        {
+            f(vectorize_tensor<N, Axis>(xs)...);
+        }
+    });
 }

 } // namespace migraphx
...
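Concretely, `vectorize<4, Axis>()` on a packed tensor reinterprets groups of four scalars as one `vec<T, 4>` lane and divides the axis length by four, while a stride-0 (broadcast) axis goes through `tensor_step` and stays a single logical value per lane. A host-side sketch of the reinterpretation, where `vec` is a toy aggregate rather than the kernel's extension-vector type and the cast mirrors the device code's `as_vec`:

#include <cstdio>

// Toy stand-in for the kernel's vector type.
template <class T, int N>
struct vec
{
    T data[N];
};

int main()
{
    float buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // lens {2, 4}, strides {4, 1}, axis 1, N = 4 -> vectorized lens {2, 1}:
    // each row becomes one vec<float, 4> lane.
    auto* v = reinterpret_cast<vec<float, 4>*>(buffer);
    std::printf("%g %g\n", v[0].data[0], v[1].data[2]); // 0 6
}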
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

namespace {

struct find_layernorm
{
    auto matcher() const { return match::layernorm(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        if(not x_ins->get_shape().standard())
            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
    }
};

struct find_triaddlayernorm
{
    auto matcher() const
    {
        auto add1 =
            match::name("add")(match::none_of(match::is_constant()),
                               match::args(match::any().bind("z1"), match::any().bind("z2")));
        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
        return match::layernorm()(match::var("x")(add2));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["z1"];
        auto y_ins = r.instructions["z2"];
        auto z_ins = r.instructions["z3"];
        for(auto* pins : {&x_ins, &y_ins, &z_ins})
        {
            if(not(*pins)->get_shape().standard())
                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
        }
        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
    }
};

} // namespace

void prefuse_ops::apply(module& m) const
{
    match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
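The `relements` guard in both matchers limits fusion to rows the gpu::layernorm kernel handles: at most 1024 elements, and rows longer than 256 only when they vectorize by 4. A quick host-side table of that predicate:

#include <cstddef>
#include <cstdio>

// Same condition the matchers use to bail out of fusing.
bool fusable(std::size_t relements)
{
    return not(relements > 1024 or (relements % 4 != 0 and relements > 256));
}

int main()
{
    for(std::size_t n : {250, 256, 258, 768, 1024, 2048})
        std::printf("%4zu -> %s\n", n, fusable(n) ? "fuse" : "skip");
    // 250 fuse, 256 fuse, 258 skip (not a multiple of 4 and > 256),
    // 768 fuse, 1024 fuse, 2048 skip (> 1024)
}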
@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
 MIGRAPHX_REGISTER_OP(set_stream)

 std::size_t schedule_model::concurrency() const { return streams; }
-void schedule_model::sched(module& p, instruction_ref ins, std::size_t n) const
+void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const
 {
     auto last_stream = std::find_if(std::make_reverse_iterator(ins),
-                                    std::make_reverse_iterator(p.begin()),
+                                    std::make_reverse_iterator(m.begin()),
                                     [&](auto&& i) { return i.name() == "gpu::set_stream"; });
-    if(last_stream != std::make_reverse_iterator(p.begin()))
+    if(last_stream != std::make_reverse_iterator(m.begin()))
     {
         auto&& op = any_cast<set_stream>(last_stream->get_operator());
         // If the same stream was set earlier then skip
         if(op.stream == n)
             return;
     }
-    p.insert_instruction(ins, set_stream{n});
+    m.insert_instruction(ins, set_stream{n});
 }
-void schedule_model::wait(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(ins, wait_event{wait_id});
+    m.insert_instruction(ins, wait_event{wait_id});
 }
-void schedule_model::record(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(std::next(ins), record_event{wait_id});
+    m.insert_instruction(std::next(ins), record_event{wait_id});
 }

 static std::unordered_map<std::string, std::size_t> create_weight_map()
...