Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum> ...@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
if(op.exclusive or op.reverse) device::prefix_scan_sum(
MIGRAPHX_THROW("Exclusive and reverse scan not supported"); ctx.get_stream().get(), args[1], args[0], op.axis, op.exclusive, op.reverse);
device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
return args[1]; return args[1];
} }
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Source template for the generated pointwise HIP kernel. The ${preamble},
// ${params}, ${lambda} and ${args} placeholders are substituted by
// interpolate_string in pointwise_compiler::compile_op before compilation.
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
)__migraphx__";
// Compiler entry for "pointwise" operations: turns a pointwise module into
// generated C++ device code (via cpp_generator) and compiles it into a HIP
// code object using the pointwise_kernel source template.
struct pointwise_compiler : compiler<pointwise_compiler>
{
    // Operation names handled by this compiler.
    std::vector<std::string> names() const { return {"pointwise"}; }

    // Oversubscription factor passed to compute_global_for: 1 when any input
    // is broadcasted, otherwise 4.
    static std::size_t oversubscribe(const std::vector<shape>& inputs)
    {
        if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
            return 1;
        else
            return 4;
    }

    // Number of work items after vectorization: when every input is packed or
    // broadcasted, the element count is divided by 4 (or 2) when evenly
    // divisible; otherwise one work item per element.
    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
    {
        std::size_t n = inputs.front().elements();
        if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
               return s.packed() or s.broadcasted();
           }))
        {
            if((n % 4) == 0)
                return n / 4;
            else if((n % 2) == 0)
                return n / 2;
        }
        return n;
    }

    // Fill the kernel source template and compile it to a code object.
    // v must contain "lambda" (the pointwise functor expression) and may
    // contain "preamble" (supporting device code emitted by the generator).
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
        options.inputs         = inputs;
        options.output         = inputs.back();
        options.virtual_inputs = reduce_dims(inputs);
        options.params         = "-Wno-float-equal";
        auto src = interpolate_string(pointwise_kernel,
                                      {{"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"lambda", v.at("lambda").to<std::string>()},
                                       {"preamble", v.get("preamble", std::string{})}});
        return compile_hip_code_object(src, options);
    }

    // Generate a device function from the instruction's pointwise module and
    // compile it; the generated function is wrapped with MIGRAPHX_LIFT so it
    // can be passed as the kernel's lambda.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
    {
        assert(not ins->module_inputs().empty());
        auto* pm = ins->module_inputs().front();
        // Simplify the module before code generation.
        run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
        cpp_generator g;
        // Qualify generated function names with the migraphx namespace.
        g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
        // Custom expression templates for ops with no direct device function.
        g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
        g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
        g.add_point_op("sign",
                       "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
        g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
        g.add_point_op("less", "migraphx::abs(${0} < ${1})");
        g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
        g.add_point_op("not", "migraphx::abs(not ${0})");
        // Add explicit conversions to each operation's result type.
        g.fresult(
            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
        auto name = g.create_function(
            g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
        std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
        return replace(
            compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}}));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compile_roialign.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,46 +43,46 @@ __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y ...@@ -37,46 +43,46 @@ __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y
} // namespace migraphx } // namespace migraphx
int main() {}
)__migraphx__"; )__migraphx__";
operation compile_roialign(context&, const std::vector<shape>& io_shapes, const value& val) struct roialign_compiler : compiler<roialign_compiler>
{ {
hip_compile_options options; std::vector<std::string> names() const { return {"roialign"}; }
auto out_s = io_shapes.back();
options.local = 128; operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
options.global = compute_global(out_s.elements(), options.local); {
options.inputs = io_shapes; hip_compile_options options;
options.output = out_s; options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 128);
options.kernel_name = "roialign_kernel"; options.output = inputs.back();
options.virtual_inputs = io_shapes; options.inputs = inputs;
options.kernel_name = "roialign_kernel";
// sampling_ratio
assert(val.contains("sampling_ratio")); // sampling_ratio
auto sampling_ratio = val.at("sampling_ratio").to<int64_t>(); options.params += " -DSAMPLING_RATIO=" + v.at("sampling_ratio").to<std::string>();
options.params += " -DSAMPLING_RATIO=" + std::to_string(sampling_ratio);
// pooling_mode
// pooling_mode auto mode = v.at("mode").to<migraphx::op::pooling_mode>();
assert(val.contains("mode")); std::string is_avg_pooling =
auto mode = val.at("mode").to<std::string>(); (mode == migraphx::op::pooling_mode::average) ? "true" : "false";
bool is_avg_pooling = (mode == "avg"); options.params += " -DIS_AVG_POOLING=" + is_avg_pooling;
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode
// coord_trans_mode auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
assert(val.contains("coordinate_transformation_mode")); float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
auto ctm = val.at("coordinate_transformation_mode").to<std::string>(); options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset); // spatial_scale
options.params += " -DSPATIAL_SCALE=" + v.at("spatial_scale").to<std::string>();
// spatial_scale
assert(val.contains("spatial_scale")); return compile_hip_code_object(roialign_kernel, options);
float spatial_scale = val.at("spatial_scale").to<float>(); }
options.params += " -DSPATIAL_SCALE=" + std::to_string(spatial_scale);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
return compile_hip_code_object(roialign_kernel, options); {
} return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
} // namespace gpu }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Source template for the scatternd HIP kernel; ${reduction} is replaced by
// one of the assign_* functors declared in kernels/scatternd.hpp.
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
});
}
}
} // namespace migraphx
)__migraphx__";
// Compiles the scatternd_none/add/mul operators into a HIP code object and
// rewires the instruction so the first argument is copied into the output
// buffer before the scatter kernel updates it in place.
struct scatternd_compiler : compiler<scatternd_compiler>
{
    // Operator names this compiler handles; the suffix names the reduction.
    std::vector<std::string> names() const
    {
        return {"scatternd_none", "scatternd_add", "scatternd_mul"};
    }

    // Instantiate the kernel template with the reduction functor and compile.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        // One work item per element of the updates tensor (inputs.at(1)).
        options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
        options.kernel_name    = "scatternd_kernel";
        options.inputs         = inputs;
        options.virtual_inputs = inputs;
        options.output         = inputs.back();
        // Map e.g. "add" to the device functor "assign_add"; default "none".
        auto functor    = "assign_" + v.get("reduction", std::string{"none"});
        auto kernel_src = interpolate_string(scatternd_kernel, {{"reduction", functor}});
        return compile_hip_code_object(kernel_src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        assert(starts_with(op.name(), "scatternd_"));
        // Drop the "scatternd_" prefix (10 chars) to recover the reduction.
        auto reduction = op.name().substr(10);
        // The first input is excluded; the kernel sees (indices, updates, output).
        auto arg_shapes = to_shapes({ins->inputs().begin() + 1, ins->inputs().end()});
        return insert(compile_op(ctx, arg_shapes, {{"reduction", reduction}}));
    }

    // Replacement that copies the first argument into the output allocation
    // via hip::copy and then runs the compiled kernel on the remaining args.
    compiler_replace insert(const operation& op) const
    {
        return [=](module& m, instruction_ref ins) {
            auto kernel_args   = ins->inputs();
            kernel_args.back() = m.insert_instruction(
                ins, make_op("hip::copy"), kernel_args.front(), kernel_args.back());
            kernel_args.erase(kernel_args.begin());
            return m.replace_instruction(ins, op, kernel_args);
        };
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,6 +21,16 @@ struct greater ...@@ -21,6 +21,16 @@ struct greater
} }
}; };
// Copy the range [first, last) into the range beginning at d_first.
// Returns the output iterator one past the last element written.
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
    for(; first != last; ++first, ++d_first)
    {
        *d_first = *first;
    }
    return d_first;
}
template <class Iterator, class Compare> template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{ {
......
...@@ -48,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=) ...@@ -48,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>; ...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>;
template <auto V> template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
// Invoke the constexpr-callable f and wrap its result in an
// integral_constant (via the _c variable template), lifting the value into
// the type system. f must be callable in a constant expression.
template <class F>
constexpr auto return_c(F f)
{
    return _c<f()>;
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP #endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
// Random-access-style iterator that computes its value on dereference: *it
// yields f(index) instead of loading from memory. F maps an index to a value
// (or reference); Iterator is the index type (diff_int unless specified).
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
    Iterator index; // current position
    F f;            // mapping from index to element

    using difference_type = diff_int;
    using reference       = decltype(f(std::declval<Iterator>()));
    using value_type      = remove_reference_t<reference>;
    using pointer         = add_pointer_t<value_type>;

    constexpr basic_iota_iterator& operator+=(diff_int n)
    {
        index += n;
        return *this;
    }

    constexpr basic_iota_iterator& operator-=(diff_int n)
    {
        index -= n;
        return *this;
    }

    constexpr basic_iota_iterator& operator++()
    {
        index++;
        return *this;
    }

    constexpr basic_iota_iterator& operator--()
    {
        index--;
        return *this;
    }

    // Post-increment: returns the iterator's previous value.
    constexpr basic_iota_iterator operator++(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index++;
        return it;
    }

    // Post-decrement: returns the iterator's previous value.
    constexpr basic_iota_iterator operator--(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index--;
        return it;
    }

    // TODO: operator->
    constexpr reference operator*() const { return f(index); }

    // Subscript at offset x from the current position.
    template <class T>
    constexpr reference operator[](T x) const
    {
        return f(index + x);
    }
};
// Construct a basic_iota_iterator starting at index x with mapping f,
// deducing both template arguments from the call.
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
    return basic_iota_iterator<F, T>{x, f};
}
// Pointer-style iterator arithmetic, defined on the underlying index.
// Adding/subtracting a diff_int moves the position; subtracting two
// iterators yields the distance between their indices.
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    return x += y;
}

template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
    return y + x;
}

template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index - y.index;
}

template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    return x -= y;
}
// Comparisons consider only the index; the mapping f does not participate,
// so iterators with different f but equal index compare equal.
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index == y.index;
}

template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index != y.index;
}

template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index < y.index;
}

template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index > y.index;
}

template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index >= y.index;
}

template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index <= y.index;
}
// Identity mapping: the iterator simply yields its own index.
// NOTE(review): "defaul" looks like a typo for "default", but renaming would
// change a public name, so it is kept as-is.
struct defaul_iota_iterator
{
    template <class T>
    constexpr auto operator()(T x) const
    {
        return x;
    }
};

// Iterator over consecutive integers: *it == it.index.
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
...@@ -59,6 +59,7 @@ MIGRAPHX_DEVICE_MATH(cosh, ::cosh) ...@@ -59,6 +59,7 @@ MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf) MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp) MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor) MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
...@@ -103,6 +104,7 @@ MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos) ...@@ -103,6 +104,7 @@ MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
...@@ -129,6 +131,7 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -129,6 +131,7 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
......
...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts> ...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
preload<typename T::type>(idx, xs...)([&](auto... ps) { preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) { idx.global_stride(out.get_shape().elements(),
auto multi_idx = out.get_shape().multi(i); [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
}); });
} }
......
...@@ -6,15 +6,32 @@ ...@@ -6,15 +6,32 @@
namespace migraphx { namespace migraphx {
template <class T>
struct remove_vec_impl
{
using type = T;
};
template <class T, index_int N>
struct remove_vec_impl<vec<T, N>>
{
using type = T;
};
template <class T>
using remove_vec = typename remove_vec_impl<T>::type;
template <class T, class... Shapes> template <class T, class... Shapes>
constexpr auto traverse_preload(Shapes... ss) constexpr auto traverse_preload(Shapes... ss)
{ {
return [=](auto f, auto... g) { return [=](auto f, auto... g) {
index_int offset = 0; index_int offset = 0;
auto each = [&](auto x) { auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = s.element_space();
if constexpr(not s.broadcasted() or (s.elements() - size) < 64) if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{})
return f(x, offset, false_type{}); return f(x, offset, false_type{});
else else
{ {
...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) ...@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
invoke); invoke);
} }
template <class T> template <class T, class Shape>
struct remove_vec struct shape_type : Shape
{ {
using type = T; using type = T;
}; };
template <class T, index_int N> template <class T>
struct remove_vec<vec<T, N>> constexpr auto make_shape_type(T)
{ {
using type = T; return shape_type<typename T::type, typename T::shape_type>{};
}; }
template <class T, class... Ts> template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs) __device__ auto preload(index idx, Ts... xs)
{ {
using type = typename remove_vec<T>::type; using type = remove_vec<T>;
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){}; constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type); const index_int max_size = 512 * sizeof(type);
return [=](auto f) { return [=](auto f) {
if constexpr(size > 0 and size < max_size) if constexpr(size > 0 and size < max_size)
......
...@@ -19,7 +19,7 @@ struct max_pool ...@@ -19,7 +19,7 @@ struct max_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int)
{ {
return (x); return (x);
} }
...@@ -36,21 +36,19 @@ struct avg_pool ...@@ -36,21 +36,19 @@ struct avg_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? 0.0 : (x / y);
} }
}; };
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const array<std::size_t, 2>& dims, const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
array<float, 2> xy,
Op pooling)
{ {
array<int, 2> low{}; array<int, 2> low{};
array<int, 2> high{}; array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii) for(index_int ii = 0; ii < xy.size(); ++ii)
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, ...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
xy[ii] = high[ii] = low[ii] = dims[ii] - 1; xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
} }
} }
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1], array<index_int, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1], low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23); return pooling(v01, v23);
} }
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data, MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
const array<float, 2>& roi_starts, const array<float, 2>& roi_starts,
const array<float, 2>& bin_size, const array<float, 2>& bin_size,
const array<int, 2>& idx, const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size, const array<index_int, 2>& bin_grid_size,
const array<std::size_t, 2>& dims, const array<index_int, 2>& dims,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
T output_val = op.init(); typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset; roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs) ...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
template <class T, class U, class V, class W, class Settings> template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s) __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{ {
auto index = make_index(); auto index = make_index();
const auto* x = x_t.data(); const auto x = x_t.begin();
const auto* rois = rois_t.data(); const auto rois = rois_t.begin();
const auto* ind = ind_t.data(); const auto ind = ind_t.begin();
auto* out_ptr = y_t.data(); auto out_ptr = y_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim // input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]}; array<index_int, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal(); const auto stride = index.nglobal();
auto out_s = y_t.get_shape(); auto out_s = y_t.get_shape();
...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
// output dims of height and width, in all 2-dim arrays, the first dim // output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
const auto& out_lens = out_s.lens; const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]}; array<index_int, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride) for(index_int i = index.global; i < out_s.elements(); i += stride)
{ {
...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
int ph = idx[2]; int ph = idx[2];
int pw = idx[3]; int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale}; offset_rois[0] * s.spatial_scale};
...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{}; array<index_int, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii) for(index_int ii = 0; ii < roi_size.size(); ++ii)
{ {
roi_size[ii] = roi_ends[ii] - roi_starts[ii]; roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f); roi_size[ii] = max(roi_size[ii], 1.0f);
...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]); (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
} }
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
out_ptr[i] = calc_pooling(offset_x, out_ptr[i] = calc_pooling(offset_x,
......
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
// Reduction functors selecting how an update value is combined with the
// existing output element in scatternd.

// Overwrite: the output element is replaced by the update value.
struct assign_none
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        x = y;
    }
};

// Accumulate: the update value is added to the output element.
struct assign_add
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        x += y;
    }
};

// Multiply: the output element is scaled by the update value.
struct assign_mul
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        x *= y;
    }
};
// Scatter the elements of updates_t into output_t at positions given by
// indices_t, combining the old and new values with f (one of the assign_*
// functors above). The last dimension of indices_t holds a k-element
// coordinate into the leading dimensions of the output; the remaining update
// coordinates are passed through to the trailing output dimensions.
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
    auto index         = make_index();
    auto updates_shape = updates_t.get_shape();
    // Iterate once per element of the updates tensor.
    index.global_stride(updates_shape.elements(), [&](auto i) {
        auto output_shape  = output_t.get_shape();
        auto indices_shape = indices_t.get_shape();
        // k: length of one index tuple (last dimension of indices).
        auto k = indices_shape.lens.back();
        // q: rank of the indices tensor.
        auto q = indices_shape.lens.size();
        auto updates_idx = updates_shape.multi(i);
        auto indices_idx = indices_shape.multi(0);
        // The first q-1 coordinates of the update element select which index
        // tuple to read from indices_t.
        copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
        auto index_start = indices_t.begin() + indices_shape.index(indices_idx);
        auto index_end   = index_start + k;
        auto out_idx = output_shape.multi(0);
        // Output coordinate = k values from the index tuple, then the
        // trailing update coordinates.
        copy(index_start, index_end, out_idx.begin());
        copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
        // Apply the reduction in place on the addressed output element.
        f(output_t[out_idx], updates_t[i]);
    });
}
} // namespace migraphx
#endif
...@@ -17,35 +17,38 @@ struct shape ...@@ -17,35 +17,38 @@ struct shape
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {} constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr index_int elements() const { return lens.product(); } constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr bool packed() const { return elements() == element_space(); } constexpr auto packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr bool transposed() const constexpr auto transposed() const
{ {
if(broadcasted()) return return_c([] {
{ auto lstrides = Strides{};
index_array s; if(shape{}.broadcasted())
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
if(strides[i] != 0) index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
s[j] = strides[i]; if(lstrides[i] != 0)
j++; {
s[j] = lstrides[i];
j++;
}
} }
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
return not is_sorted(s.begin(), s.begin() + j, greater{}); else
} {
else return not is_sorted(lstrides.begin(), lstrides.end(), greater{});
{ }
return not is_sorted(strides.begin(), strides.end(), greater{}); });
}
} }
constexpr bool standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
constexpr index_int index(index_array x) const { return x.dot(strides); } constexpr index_int index(index_array x) const { return x.dot(strides); }
...@@ -63,10 +66,10 @@ struct shape ...@@ -63,10 +66,10 @@ struct shape
return i; return i;
else else
{ {
const index_int rank = this->lens.size(); const auto rank = this->lens.size();
index_int s = 1; index_int s = 1;
index_int result = 0; index_int result = 0;
for(index_int j = 0; j < this->lens.size(); j++) for(index_int j = 0; j < rank; j++)
{ {
const index_int k = rank - j - 1; const index_int k = rank - j - 1;
const index_int stride = this->strides[k]; const index_int stride = this->strides[k];
......
...@@ -3,16 +3,30 @@ ...@@ -3,16 +3,30 @@
#include <migraphx/kernels/shape.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx { namespace migraphx {
template <class T>
struct tensor_view_iterator_read
{
T* view;
constexpr auto& operator()(std::size_t n) const
{
MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n];
}
};
template <class T, class Shape> template <class T, class Shape>
struct tensor_view struct tensor_view
{ {
using type = T; using type = T;
using shape_type = Shape;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); } constexpr auto size() const { return get_shape().elements(); }
template <class U> template <class U>
constexpr T& operator[](U i) const constexpr T& operator[](U i) const
...@@ -23,8 +37,8 @@ struct tensor_view ...@@ -23,8 +37,8 @@ struct tensor_view
constexpr T* data() const { return x; } constexpr T* data() const { return x; }
constexpr T* begin() const { return data(); } constexpr auto begin() const { return iterator{0, {this}}; }
constexpr T* end() const { return data() + size(); } constexpr auto end() const { return iterator{this->size(), {this}}; }
template <class U> template <class U>
constexpr tensor_view<U, Shape> with(U* y) const constexpr tensor_view<U, Shape> with(U* y) const
......
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
namespace migraphx { namespace migraphx {
template <class T>
struct type_identity
{
using type = T;
};
template <bool B, class T = void> template <bool B, class T = void>
struct enable_if struct enable_if
{ {
...@@ -25,6 +31,43 @@ struct is_convertible : bool_constant<__is_convertible(From, To)> ...@@ -25,6 +31,43 @@ struct is_convertible : bool_constant<__is_convertible(From, To)>
{ {
}; };
template <class T, class U>
struct is_same : false_type
{
};
template <class T>
struct is_same<T, T> : true_type
{
};
template <class T>
struct remove_reference
{
using type = T;
};
template <class T>
struct remove_reference<T&>
{
using type = T;
};
template <class T>
struct remove_reference<T&&>
{
using type = T;
};
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
template <class T>
struct add_pointer : type_identity<typename remove_reference<T>::type*>
{
};
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
namespace migraphx { namespace migraphx {
using index_int = std::uint32_t; using index_int = std::uint32_t;
using diff_int = std::int32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
...@@ -66,15 +66,18 @@ __device__ __host__ auto as_vec(T* x) ...@@ -66,15 +66,18 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x); return reinterpret_cast<vec<T, N>*>(x);
} }
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts> template <class... Ts>
constexpr auto vec_transform(Ts... xs) constexpr auto vec_transform(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) {
if constexpr(is_any_vec<Ts...>()) if constexpr(is_any_vec<Ts...>())
{ {
using type = decltype(f(vec_at(xs, 0)...)); using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>(); constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0}; safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++) for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...); result[i] = f(vec_at(xs, i)...);
return result; return result;
......
...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis) ...@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
}); });
} }
// Bools can not be used as a vector type so convert it to int8 // Bools can not be used as a vector type so convert it to uint8
template <class T> template <class T>
__device__ __host__ T* remove_bool(T* x) __device__ __host__ T* remove_bool(T* x)
{ {
return x; return x;
} }
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); } inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis) __device__ __host__ auto as_vec(T x, Axis axis)
......
...@@ -20,10 +20,10 @@ ...@@ -20,10 +20,10 @@
#include <migraphx/gpu/abs.hpp> #include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp> #include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp> #include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp> #include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp> #include <migraphx/gpu/where.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <utility> #include <utility>
...@@ -60,6 +61,7 @@ struct miopen_apply ...@@ -60,6 +61,7 @@ struct miopen_apply
std::unordered_map<instruction_ref, std::string> prog_output_names{}; std::unordered_map<instruction_ref, std::string> prog_output_names{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
context& get_context() const context& get_context() const
{ {
...@@ -96,13 +98,22 @@ struct miopen_apply ...@@ -96,13 +98,22 @@ struct miopen_apply
} }
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context(); auto& ctx = get_context();
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag; rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4); int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
...@@ -183,8 +194,6 @@ struct miopen_apply ...@@ -183,8 +194,6 @@ struct miopen_apply
add_extend_op("softmax"); add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
add_convolution_op(); add_convolution_op();
add_deconvolution_op(); add_deconvolution_op();
...@@ -195,7 +204,6 @@ struct miopen_apply ...@@ -195,7 +204,6 @@ struct miopen_apply
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_quant_convolution_op(); add_quant_convolution_op();
add_roialign();
} }
void copy_params() void copy_params()
...@@ -249,11 +257,28 @@ struct miopen_apply ...@@ -249,11 +257,28 @@ struct miopen_apply
{ {
check_shape(s, apply_map.at(it->name())(it)); check_shape(s, apply_map.at(it->name())(it));
} }
else if(has_compiler_for(it->name()))
{
check_shape(s, insert_precompile_op(it));
}
} }
copy_params(); copy_params();
} }
instruction_ref insert_precompile_op(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
}
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "") instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
{ {
// Instruction's output is an input of the ret instruction // Instruction's output is an input of the ret instruction
...@@ -337,7 +362,7 @@ struct miopen_apply ...@@ -337,7 +362,7 @@ struct miopen_apply
} }
} }
return mod->replace_instruction( return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs); ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
}); });
} }
...@@ -383,21 +408,6 @@ struct miopen_apply ...@@ -383,21 +408,6 @@ struct miopen_apply
}); });
} }
void add_precompile_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
refs,
ins->module_inputs());
});
}
void add_batch_norm_inference_op() void add_batch_norm_inference_op()
{ {
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) { apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
...@@ -432,7 +442,6 @@ struct miopen_apply ...@@ -432,7 +442,6 @@ struct miopen_apply
reshapes[2], reshapes[2],
reshapes[3], reshapes[3],
output); output);
}); });
} }
...@@ -489,22 +498,6 @@ struct miopen_apply ...@@ -489,22 +498,6 @@ struct miopen_apply
}); });
} }
void add_roialign()
{
apply_map.emplace("roialign", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
auto co = compile_roialign(get_context(), io_shapes, op_val);
return mod->replace_instruction(ins, co, args);
});
}
// replace the loop operator with gpu_loop operator // replace the loop operator with gpu_loop operator
void add_loop_op() void add_loop_op()
{ {
......
...@@ -45,7 +45,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -45,7 +45,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
struct id_pass struct id_pass
{ {
...@@ -101,7 +101,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -101,7 +101,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
fuse_pointwise{}, enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{}, dead_code_elimination{},
fuse_mlir{&ctx}, fuse_mlir{&ctx},
dead_code_elimination{}, dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment