Commit 13403ab2 authored by Umang Yadav's avatar Umang Yadav
Browse files

roialign, softmax, pow, acosh, atanh,pad tests are enabled now

parent 1be95870
...@@ -67,7 +67,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) ...@@ -67,7 +67,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::float_type: return dt::f32; case st::float_type: return dt::f32;
case st::int32_type: return dt::s32; case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8; case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8; case st::uint8_type:
case st::fp8e4m3fnuz_type: return dt::u8; case st::fp8e4m3fnuz_type: return dt::u8;
default: MIGRAPHX_THROW("Unsupported data type"); default: MIGRAPHX_THROW("Unsupported data type");
} }
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast"
#endif // __clang__ #endif // __clang__
#define MIGRAPHX_HIP_DEVICE __device__ #define MIGRAPHX_HIP_DEVICE __device__
...@@ -132,7 +133,7 @@ struct float8 ...@@ -132,7 +133,7 @@ struct float8
// NOTE: ON-DEVICE... always optimal bias // NOTE: ON-DEVICE... always optimal bias
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
float8(float v, float8(const float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
{ {
...@@ -145,8 +146,7 @@ struct float8 ...@@ -145,8 +146,7 @@ struct float8
#else #else
// DEVICE for non-gfx940 using s/w simulation // DEVICE for non-gfx940 using s/w simulation
explicit constexpr MIGRAPHX_HIP_DEVICE explicit constexpr MIGRAPHX_HIP_DEVICE
#endif float8(const float v,
float8(float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0) uint32_t rng = 0)
{ {
...@@ -175,7 +175,42 @@ struct float8 ...@@ -175,7 +175,42 @@ struct float8
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING} #endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
} }
} }
#endif // __gfx940___
// Constructor from half
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from int
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from uint
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from double
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)v, rm, rng)
{
}
// constructor from bool
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0)
: float8((float)(v), rm, rng)
{
}
// convert to float // convert to float
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr // NOLINT #if 0 // need constexpr operator(). This version can't be constexpr // NOLINT
...@@ -209,6 +244,8 @@ struct float8 ...@@ -209,6 +244,8 @@ struct float8
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
} }
inline constexpr explicit MIGRAPHX_HIP_DEVICE operator bool() const { return not is_zero(); }
// check for zero // check for zero
inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const
{ {
......
...@@ -39,6 +39,7 @@ __device__ void pad(const index& idx, ...@@ -39,6 +39,7 @@ __device__ void pad(const index& idx,
const PadVal& pad_val) const PadVal& pad_val)
{ {
auto output_shape = output.get_shape(); auto output_shape = output.get_shape();
using otype = typename Output::type;
idx.global_stride(output_shape.elements(), [&](auto i) { idx.global_stride(output_shape.elements(), [&](auto i) {
// 1. get current multi-index for output // 1. get current multi-index for output
// 2. get the size of the input to determine input boundaries // 2. get the size of the input to determine input boundaries
...@@ -53,9 +54,9 @@ __device__ void pad(const index& idx, ...@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) { if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j]; return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
})) }))
output[multi] = pad_val; output[multi] = otype(pad_val);
else else
output[multi] = input[input_idx]; output[multi] = otype(input[input_idx]);
}); });
} }
......
...@@ -392,7 +392,7 @@ struct block ...@@ -392,7 +392,7 @@ struct block
{ {
using max_iterations = decltype(idx.max_local_stride_iterations(n)); using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage; inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; });
return storage; return storage;
} }
}; };
......
...@@ -56,13 +56,13 @@ struct avg_pool ...@@ -56,13 +56,13 @@ struct avg_pool
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
{ {
return x + y; return static_cast<T>(x + y);
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? static_cast<T>(0.0) : static_cast<T>(x / y);
} }
}; };
...@@ -70,13 +70,14 @@ template <class Iterator, class Op> ...@@ -70,13 +70,14 @@ template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling) const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
{ {
using ret_type = typename Iterator::value_type;
array<int, 2> low{}; array<int, 2> low{};
array<int, 2> high{}; array<int, 2> high{};
for(index_int ii = 0; ii < xy.size(); ++ii) for(index_int ii = 0; ii < xy.size(); ++ii)
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
return 0; return static_cast<ret_type>(0);
} }
xy[ii] = migraphx::max(xy[ii], 0.0f); xy[ii] = migraphx::max(xy[ii], 0.0f);
...@@ -92,11 +93,14 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( ...@@ -92,11 +93,14 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; array<ret_type, 4> ws = {static_cast<ret_type>(hy * hx),
static_cast<ret_type>(hy * lx),
static_cast<ret_type>(ly * hx),
static_cast<ret_type>(ly * lx)};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
...@@ -113,8 +117,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, ...@@ -113,8 +117,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
typename Iterator::value_type output_val = op.init(); using in_dtype = typename Iterator::value_type;
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; in_dtype output_val = in_dtype{op.init()};
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<index_int, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
...@@ -148,7 +153,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -148,7 +153,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto x = x_t.begin(); const auto x = x_t.begin();
const auto rois = rois_t.begin(); const auto rois = rois_t.begin();
const auto ind = ind_t.begin(); const auto ind = ind_t.begin();
using ytype = typename W::type;
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
...@@ -176,10 +181,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -176,10 +181,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {
offset_rois[0] * s.spatial_scale}; static_cast<float>(offset_rois[1]) * static_cast<float>(s.spatial_scale),
array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale, static_cast<float>(offset_rois[0]) * static_cast<float>(s.spatial_scale)};
offset_rois[2] * s.spatial_scale}; array<float, 2> roi_ends = {
static_cast<float>(offset_rois[3]) * static_cast<float>(s.spatial_scale),
static_cast<float>(offset_rois[2]) * static_cast<float>(s.spatial_scale)};
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
...@@ -199,25 +206,25 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -199,25 +206,25 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
y_t[i] = calc_pooling(offset_x, y_t[i] = static_cast<ytype>(calc_pooling(offset_x,
roi_starts, roi_starts,
bin_size, bin_size,
{ph, pw}, {ph, pw},
bin_grid_size, bin_grid_size,
in_dims, in_dims,
s.roi_offset, s.roi_offset,
avg_pool{}); avg_pool{}));
} }
else else
{ {
y_t[i] = calc_pooling(offset_x, y_t[i] = static_cast<ytype>(calc_pooling(offset_x,
roi_starts, roi_starts,
bin_size, bin_size,
{ph, pw}, {ph, pw},
bin_grid_size, bin_grid_size,
in_dims, in_dims,
s.roi_offset, s.roi_offset,
max_pool{}); max_pool{}));
} }
} }
} }
......
...@@ -33,6 +33,7 @@ template <index_int Axis, class Input, class Output> ...@@ -33,6 +33,7 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input1, Output output) __device__ void softmax(Input input1, Output output)
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>; using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
using otype = typename Output::type;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1); auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
...@@ -43,7 +44,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -43,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in); r.inner([&](auto& y, auto x) { y = otype{x / batch_sum}; })(output, exp_in);
}); });
} }
......
...@@ -23,21 +23,23 @@ ...@@ -23,21 +23,23 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/literal.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> template <typename CType>
struct test_acosh : verify_program<test_acosh<DType>> struct test_acosh : verify_program<test_acosh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::type_t DType = migraphx::shape::get_type<CType>();
migraphx::shape s{DType, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.1f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {1.1}});
auto max_val = mm->add_literal(100.0f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {100.0}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -48,6 +50,6 @@ struct test_acosh : verify_program<test_acosh<DType>> ...@@ -48,6 +50,6 @@ struct test_acosh : verify_program<test_acosh<DType>>
} }
}; };
template struct test_acosh<migraphx::shape::float_type>; template struct test_acosh<float>;
// template struct test_acosh<migraphx::shape::half_type>; template struct test_acosh<migraphx::half>;
// template struct test_acosh<migraphx::shape::fp8e4m3fnuz_type>; template struct test_acosh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -23,21 +23,24 @@ ...@@ -23,21 +23,24 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> template <typename CType>
struct test_atanh : verify_program<test_atanh<DType>> struct test_atanh : verify_program<test_atanh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::type_t DType = migraphx::shape::get_type<CType>();
migraphx::shape s{DType, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {-0.95f}});
auto max_val = mm->add_literal(0.95f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {0.95f}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -48,6 +51,6 @@ struct test_atanh : verify_program<test_atanh<DType>> ...@@ -48,6 +51,6 @@ struct test_atanh : verify_program<test_atanh<DType>>
} }
}; };
template struct test_atanh<migraphx::shape::float_type>; template struct test_atanh<float>;
// template struct test_atanh<migraphx::shape::half_type>; template struct test_atanh<migraphx::half>;
// template struct test_atanh<migraphx::shape::fp8e4m3fnuz_type>; template struct test_atanh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -51,4 +51,4 @@ struct test_pad : verify_program<test_pad<DType>> ...@@ -51,4 +51,4 @@ struct test_pad : verify_program<test_pad<DType>>
template struct test_pad<migraphx::shape::int32_type>; template struct test_pad<migraphx::shape::int32_type>;
template struct test_pad<migraphx::shape::float_type>; template struct test_pad<migraphx::shape::float_type>;
template struct test_pad<migraphx::shape::half_type>; template struct test_pad<migraphx::shape::half_type>;
// template struct test_pad<migraphx::shape::fp8e4m3fnuz_type>; template struct test_pad<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,13 +27,15 @@ ...@@ -27,13 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_pow : verify_program<test_pow> template <typename CType>
struct test_pow : verify_program<test_pow<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); migraphx::shape::type_t DType = migraphx::shape::get_type<CType>();
migraphx::shape s{migraphx::shape::float_type, {6}}; auto* mm = p.get_main_module();
migraphx::shape s{DType, {6}};
std::vector<float> vec_e(s.elements(), 2.0f); std::vector<float> vec_e(s.elements(), 2.0f);
auto b = mm->add_parameter("x", s); auto b = mm->add_parameter("x", s);
auto e = mm->add_literal(migraphx::literal(s, vec_e)); auto e = mm->add_literal(migraphx::literal(s, vec_e));
...@@ -41,4 +43,6 @@ struct test_pow : verify_program<test_pow> ...@@ -41,4 +43,6 @@ struct test_pow : verify_program<test_pow>
return p; return p;
} }
}; };
// TODO: add fp8 tests template struct test_pow<float>;
template struct test_pow<migraphx::half>;
template struct test_pow<migraphx::fp8::fp8e4m3fnuz>;
...@@ -59,5 +59,5 @@ struct test_roialign : verify_program<test_roialign<DType>> ...@@ -59,5 +59,5 @@ struct test_roialign : verify_program<test_roialign<DType>>
}; };
template struct test_roialign<migraphx::shape::float_type>; template struct test_roialign<migraphx::shape::float_type>;
// template struct test_roialign<migraphx::shape::half_type>; template struct test_roialign<migraphx::shape::half_type>;
// template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>; template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -48,7 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; ...@@ -48,7 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>;
// template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment