"src/vscode:/vscode.git/clone" did not exist on "52c5717f9704ae270acd5cba74c11eb8328eb20a"
Commit 8d7a8a6c authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 25b33431 a09dc502
...@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, ...@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens; auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens; auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back(); auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{}); accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(), data_shape_lens.end(),
1, 1,
op::product{}); op::product{});
const std::size_t num_batches = const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{}); accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride = const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{}); accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches; const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) { ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data(); const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size; const size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch; const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims); auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0; size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx) for(size_t idx = 0; idx < num_slice_dims; ++idx)
{ {
int64_t index = slice_indices[idx]; int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx; const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx]; const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim)); index < static_cast<int64_t>(input_dim));
if(index < 0) if(index < 0)
index += input_dim; index += input_dim;
std::size_t size_from_slice_dims = size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1, accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims, data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size, slice_size,
......
...@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm( ...@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm(
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_type<value_type>{1.0 / relements}; constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r); auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { auto means = r.reduce(op::sum{},
auto x_out = x * relements_r; make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
// dividing x by sqrt(relements) before squaring allows computing higher values [&](auto x) {
// before overflow in low precision auto x_out = x * relements_r;
auto x2_sqrt = x * relements_rsqrt; // dividing x by sqrt(relements) before squaring allows computing
return make_array(x_out, x2_sqrt * x2_sqrt); // higher values before overflow in low precision
})(input); auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
......
...@@ -29,11 +29,15 @@ ...@@ -29,11 +29,15 @@
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx { namespace migraphx {
namespace math { namespace math {
constexpr float as_float(migraphx::half x) { return x; } constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; }
template <class T> template <class T>
constexpr T as_float(T x) constexpr T as_float(T x)
{ {
...@@ -57,14 +61,14 @@ constexpr T as_float(T x) ...@@ -57,14 +61,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \ auto __device__ name(type x, Ts... xs) -> type \
{ \ { \
return fname(x, xs...); \ return fname(x, xs...); \
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); } inline auto __device__ name(type x, type y) -> type { return fname(x, y); }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ #define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
...@@ -72,6 +76,12 @@ constexpr T as_float(T x) ...@@ -72,6 +76,12 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \ auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// Template with two overloads for math functions, one for half2 type and one for more generic // Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number. // <half, N> vectorization where N is 4 or another even number.
...@@ -162,6 +172,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) ...@@ -162,6 +172,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin)
MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan)
MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos)
MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf)
MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp)
MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor)
MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_FP8(log, ::log)
MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow)
MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_FP8(round, ::round)
MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)
MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan)
MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
...@@ -253,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where) ...@@ -253,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U> template <class T, class U>
constexpr auto convert(U v) constexpr auto convert(U v)
{ {
return vec_transform(v)([](auto x) -> T { return x; }); return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp> #include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx { namespace migraphx {
...@@ -53,9 +54,9 @@ __device__ void pad(const index& idx, ...@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) { if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j]; return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
})) }))
output[multi] = pad_val; output[multi] = implicit_conversion(pad_val);
else else
output[multi] = input[input_idx]; output[multi] = implicit_conversion(input[input_idx]);
}); });
} }
......
...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out); in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64 #if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<0x1e0>(in);
in = op(in, out);
#else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in); out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
...@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
} }
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK) #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1 #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64 #elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \ "s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); \
(void)f
#else #else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \ #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \ __device__ inline void dpp_reduce(double& x, op f) \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \ { \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
__device__ inline void dpp_reduce(int32_t& x, op) \ } \
{ \ __device__ inline void dpp_reduce(float& x, op f) \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \ { \
} \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); } } \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
// Note: when max and min are in int32_t, signed version of instruction needs to be used. // Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u) MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
...@@ -99,14 +117,10 @@ template <class Op, class T, class Index, class F> ...@@ -99,14 +117,10 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -117,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -117,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
} }
__syncthreads(); __syncthreads();
type y = init; type y = type(init);
for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++) for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
{ {
y = op(y, buffer[i]); y = op(y, buffer[i]);
...@@ -244,9 +258,8 @@ struct reducer_base ...@@ -244,9 +258,8 @@ struct reducer_base
{ {
auto&& derived = static_cast<const Derived&>(*this); auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x); auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& { return make_storage_access<typename decltype(t)::type>(
return t[i]; [=](auto i, auto...) -> auto& { return t[i]; });
});
} }
} }
...@@ -393,7 +406,7 @@ struct block ...@@ -393,7 +406,7 @@ struct block
{ {
using max_iterations = decltype(idx.max_local_stride_iterations(n)); using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage; inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; });
return storage; return storage;
} }
}; };
...@@ -482,7 +495,7 @@ struct lane ...@@ -482,7 +495,7 @@ struct lane
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{ {
using type = remove_reference_t<decltype(x(0, _c<0>))>; using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init; type r = type(init);
for(index_int j = 0; j < n; j++) for(index_int j = 0; j < n; j++)
{ {
r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...)); r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
......
...@@ -62,7 +62,7 @@ struct avg_pool ...@@ -62,7 +62,7 @@ struct avg_pool
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? T{0.0} : T{x / y};
} }
}; };
...@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( ...@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
return 0; return implicit_conversion(0);
} }
xy[ii] = migraphx::max(xy[ii], 0.0f); xy[ii] = migraphx::max(xy[ii], 0.0f);
...@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( ...@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; // do calculations in floating point and convert final result to required type
array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23); return implicit_conversion(pooling(v01, v23));
} }
template <class Iterator, class Op> template <class Iterator, class Op>
...@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, ...@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
typename Iterator::value_type output_val = op.init(); using in_dtype = typename Iterator::value_type;
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; in_dtype output_val = in_dtype{op.init()};
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<index_int, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
...@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto x = x_t.begin(); const auto x = x_t.begin();
const auto rois = rois_t.begin(); const auto rois = rois_t.begin();
const auto ind = ind_t.begin(); const auto ind = ind_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
...@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, ...@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {
offset_rois[0] * s.spatial_scale}; static_cast<float>(offset_rois[1]) * static_cast<float>(s.spatial_scale),
array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale, static_cast<float>(offset_rois[0]) * static_cast<float>(s.spatial_scale)};
offset_rois[2] * s.spatial_scale}; array<float, 2> roi_ends = {
static_cast<float>(offset_rois[3]) * static_cast<float>(s.spatial_scale),
static_cast<float>(offset_rois[2]) * static_cast<float>(s.spatial_scale)};
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,42 +21,63 @@ ...@@ -21,42 +21,63 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/argument.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context; struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicAdd(&x, y);
}
};
struct hip_gather struct assign_mul
{ {
op::gather op; template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
T old = x;
T assumed;
do
{
assumed = old;
old = atomicCAS(&x, assumed, assumed * y);
} while(assumed != old);
}
};
template <class Self, class F> struct assign_max
static auto reflect(Self& self, F f) {
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{ {
return migraphx::reflect(self.op, f); atomicMax(&x, y);
} }
};
std::string name() const { return "gpu::gather"; } struct assign_min
shape compute_shape(std::vector<shape> inputs) const; {
argument template <typename T, typename U>
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; atomicMin(&x, y);
} }
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -26,36 +26,10 @@ ...@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace migraphx { namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F> template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f) __device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{ {
......
...@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in); r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
}); });
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp> #include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -251,7 +251,7 @@ constexpr T numeric_max() ...@@ -251,7 +251,7 @@ constexpr T numeric_max()
} }
template <class T> template <class T>
constexpr T numeric_lowest() constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
{ {
......
...@@ -207,7 +207,7 @@ struct implicit_conversion_op ...@@ -207,7 +207,7 @@ struct implicit_conversion_op
template <class U> template <class U>
constexpr operator U() const constexpr operator U() const
{ {
return x; return static_cast<U>(x);
} }
}; };
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Support.h> #include <mlir-c/Support.h>
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck // Only undefine when not using cppcheck
#ifndef CPPCHECK #ifndef CPPCHECK
...@@ -73,6 +73,7 @@ namespace gpu { ...@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
...@@ -318,31 +319,30 @@ struct mlir_program ...@@ -318,31 +319,30 @@ struct mlir_program
return result; return result;
} }
MlirType make_tensor(const shape& s) const MlirType make_mlir_shaped(const shape& s) const
{ {
if(not s.standard())
MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
if(s.dynamic()) if(s.dynamic())
MIGRAPHX_THROW("MLIR does not support dynamic shapes"); MIGRAPHX_THROW("MLIR does not support dynamic shapes");
std::vector<int64_t> lens(s.lens().begin(), s.lens().end()); std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet( std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull()); return rocmlirMIXRShapedTypeGet(
lens.size(), lens.data(), strides.data(), make_type(s.type()));
} }
template <class Range> template <class Range>
std::vector<MlirType> make_tensors(const Range& r) std::vector<MlirType> make_mlir_shapeds(const Range& r)
{ {
std::vector<MlirType> result; std::vector<MlirType> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) { std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
return make_tensor(s); return make_mlir_shaped(s);
}); });
return result; return result;
} }
MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs) MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
{ {
auto in = make_tensors(inputs); auto in = make_mlir_shapeds(inputs);
auto out = make_tensors(outputs); auto out = make_mlir_shapeds(outputs);
return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data()); return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
} }
...@@ -504,11 +504,7 @@ struct mlir_program ...@@ -504,11 +504,7 @@ struct mlir_program
mlir_operation_state& add_results(const std::vector<shape>& outputs) mlir_operation_state& add_results(const std::vector<shape>& outputs)
{ {
std::vector<shape> reshaped(outputs.size()); auto x = prog->make_mlir_shapeds(outputs);
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
if(not x.empty()) if(not x.empty())
{ {
mlirOperationStateAddResults(&op_state, x.size(), x.data()); mlirOperationStateAddResults(&op_state, x.size(), x.data());
...@@ -581,7 +577,7 @@ struct mlir_program ...@@ -581,7 +577,7 @@ struct mlir_program
std::vector<shape> outputs = m.get_output_shapes(); std::vector<shape> outputs = m.get_output_shapes();
std::vector<MlirLocation> arg_locs(inputs.size(), location); std::vector<MlirLocation> arg_locs(inputs.size(), location);
auto body_inputs = make_tensors(inputs); auto body_inputs = make_mlir_shapeds(inputs);
mlir_region region = mlirRegionCreate(); mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data()); mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
MlirBlock result = fbody.get(); MlirBlock result = fbody.get();
...@@ -607,7 +603,7 @@ struct mlir_program ...@@ -607,7 +603,7 @@ struct mlir_program
return "func.return"; return "func.return";
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
return "tosa.const"; return "migraphx.literal";
} }
return "migraphx." + ins->name(); return "migraphx." + ins->name();
} }
...@@ -666,7 +662,8 @@ struct mlir_program ...@@ -666,7 +662,8 @@ struct mlir_program
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
literal r = ins->get_literal(); literal r = ins->get_literal();
MlirType tensor_type = make_tensor(ins->get_shape()); MlirType shaped_type = make_mlir_shaped(ins->get_shape());
MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
MlirAttribute mlir_value_attr = MlirAttribute mlir_value_attr =
mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data()); mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
ops.add_attributes({{"value", mlir_value_attr}}); ops.add_attributes({{"value", mlir_value_attr}});
...@@ -796,7 +793,9 @@ struct mlir_program ...@@ -796,7 +793,9 @@ struct mlir_program
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{})) if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive; tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)}; mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get()))) const auto limit =
value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get())) if(not mlirRockTuningParamGet(params.get(), i, param.get()))
...@@ -942,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs) ...@@ -942,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
auto param = m.get_parameter(name); auto param = m.get_parameter(name);
if(input.standard()) if(input.standard())
continue; continue;
auto lens = input.lens(); auto new_param = m.add_parameter(name + ".0", input);
auto strides = input.strides();
std::vector<operation> ops;
if(input.transposed())
{
auto perm = find_permutation(input);
auto iperm = invert_permutation(perm);
lens = reorder_dims(lens, iperm);
strides = reorder_dims(strides, iperm);
ops.push_back(make_op("transpose", {{"permutation", perm}}));
}
if(input.broadcasted())
{
std::transform(lens.begin(),
lens.end(),
strides.begin(),
lens.begin(),
[](auto len, auto stride) -> std::size_t {
if(stride == 0)
return 1;
return len;
});
ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
}
auto new_param =
std::accumulate(ops.begin(),
ops.end(),
m.add_parameter(name + ".0", shape{input.type(), lens}),
[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
m.replace_instruction(param, new_param); m.replace_instruction(param, new_param);
m.remove_instruction(param); m.remove_instruction(param);
} }
...@@ -1032,6 +1003,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, ...@@ -1032,6 +1003,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
return mp.get_tuning_config(exhaustive); return mp.get_tuning_config(exhaustive);
} }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL #ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp> #include <migraphx/gpu/ck.hpp>
#endif #endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -124,34 +125,55 @@ struct find_add_layernorm ...@@ -124,34 +125,55 @@ struct find_add_layernorm
} }
}; };
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{ {
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; } std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
}; };
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm); MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) auto is_ck_gemm()
{ {
if(ins->name() != "dot") return match::make_basic_pred_matcher([=](instruction_ref ins) {
return false; #ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type())) if(not enabled(MIGRAPHX_ENABLE_CK{}))
return false;
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
#else
(void)ins;
return false; return false;
return true; #endif
});
}
auto is_mlir_gemm()
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(not mlir_attention_enabled())
return false;
if(ins->name() != "dot")
return false;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
});
});
} }
struct find_gemm_softmax_gemm struct find_gemm_softmax_gemm
{ {
auto matcher() const auto matcher() const
{ {
auto gemm1 = auto gemm1 = match::skip(match::name("contiguous"))(
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")( auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax"); auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax)); return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
...@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm ...@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm
} }
}; };
#endif
} // namespace } // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const void prefuse_ops::apply(module_pass_manager& mpm) const
...@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const ...@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{}); match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{}); match::find_matches(mpm.get_module(), find_add_layernorm{});
#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL match::find_matches(mpm, find_gemm_softmax_gemm{});
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
} }
} // namespace gpu } // namespace gpu
......
...@@ -53,6 +53,16 @@ bool get_compute_fp32_flag() ...@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return (starts_with(device_name, "gfx9") and device_name >= "gfx908"); return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
} }
bool rocblas_fp8_available()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -98,12 +98,28 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -98,12 +98,28 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx.set_exhaustive_tune_flag(options.exhaustive_tune); ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::int8_type);
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type); unsupported_types.erase(shape::type_t::tuple_type);
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// add all device kernels
unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero");
unsupported_fp8_ops.insert("prefix_scan_sum");
unsupported_fp8_ops.insert("scatter_none");
unsupported_fp8_ops.insert("topk");
unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
unsupported_fp8_ops.insert("multinomial");
unsupported_fp8_ops.insert("argmax");
unsupported_fp8_ops.insert("argmin");
// clang-format off // clang-format off
return return
{ {
...@@ -135,6 +151,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -135,6 +151,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails) ...@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh")); auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh}); mm->add_return({tanh});
} }
migraphx::program p2(p1); // This pass should not fuse as int32_t tanh isn't supported.
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1); run_pass(p1);
EXPECT(p1 == p2); auto* mm = p1.get_main_module();
bool has_pointwise =
std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
EXPECT(has_pointwise);
} }
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
......
...@@ -139,7 +139,8 @@ const std::string math_template = R"__migraphx__( ...@@ -139,7 +139,8 @@ const std::string math_template = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp> #include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/math.hpp> #include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
using namespace migraphx;
namespace migraphx {
extern "C" { extern "C" {
__global__ void kernel(${type}* p) __global__ void kernel(${type}* p)
{ {
...@@ -148,6 +149,7 @@ __global__ void kernel(${type}* p) ...@@ -148,6 +149,7 @@ __global__ void kernel(${type}* p)
} }
} }
}
int main() {} int main() {}
...@@ -348,18 +350,19 @@ TEST_CASE(compile_math) ...@@ -348,18 +350,19 @@ TEST_CASE(compile_math)
auto vec_sizes = {2, 4, 6}; auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types()) for(auto&& t : migraphx::shape::types())
{ {
if(contains({migraphx::shape::bool_type, if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
data_types.push_back(name); data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { // fp8 doesn't have vectorization support yet, therefore skip it for now.
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; if(t != migraphx::shape::fp8e4m3fnuz_type)
}); {
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
} }
migraphx::shape input{migraphx::shape::float_type, {5, 2}}; migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options; migraphx::gpu::hip_compile_options options;
...@@ -429,7 +432,6 @@ TEST_CASE(assert_type_min_max) ...@@ -429,7 +432,6 @@ TEST_CASE(assert_type_min_max)
min = std::to_string(as.min()); min = std::to_string(as.min());
max = std::to_string(as.max()); max = std::to_string(as.max());
} }
auto src = migraphx::interpolate_string(assert_template, auto src = migraphx::interpolate_string(assert_template,
{{"type", name}, {"max", max}, {"min", min}}); {{"type", name}, {"max", max}, {"min", min}});
migraphx::shape input{migraphx::shape::float_type, {5, 2}}; migraphx::shape input{migraphx::shape::float_type, {5, 2}};
......
...@@ -141,9 +141,9 @@ TEST_CASE(conv) ...@@ -141,9 +141,9 @@ TEST_CASE(conv)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
return %0 : tensor<1x2x2x2xf32> return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -160,15 +160,38 @@ module { ...@@ -160,15 +160,38 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_add_relu) TEST_CASE(conv_add_relu)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
return %2 : tensor<1x2x2x2xf32> return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add) ...@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32> %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32> %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
return %1 : tensor<1x5x3xi32> return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -219,10 +242,10 @@ TEST_CASE(dot_add) ...@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : tensor<1x5x3xf32> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize) ...@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32> %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
%1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32> %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
return %2 : tensor<1x2x2x2xi32> return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -278,10 +301,10 @@ TEST_CASE(dot_convert) ...@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16> %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
return %1 : tensor<1x5x3xf16> return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -304,10 +327,10 @@ TEST_CASE(dot_where) ...@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : tensor<1x5x3xf32> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
......
a5537f2f563d4975c7e6121a7eb260bbbfd9455a d69842226b47e5336568103541b071447caeb9bf
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment