Unverified commit c4cee345, authored by Umang Yadav and committed by GitHub

Merge branch 'develop' into rocblas_fp8

parents c40a39c3 eafd55de
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        const auto inputs =
            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
        hip_compile_options options;
        options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
        options.inputs         = inputs;
        options.output         = inputs.back();
        options.kernel_name    = derived().get_kernel_name(op);
        options.virtual_inputs = inputs;
        // The compiler warns about the floating-point inequality comparison in assign_mul even
        // though that comparison is intentional there, so suppress the warning.
        options.params += "-Wno-float-equal";
        const auto src = derived().make_interpolated_string(op);
        return prepend_copy_data_to_output(compile_hip_code_object(src, options));
    }

    compiler_replace prepend_copy_data_to_output(const operation& co) const
    {
        return {co, [](module& m, instruction_ref ins, const operation& op) {
                    auto args = ins->inputs();
                    args.back() =
                        m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
                    args.erase(args.begin());
                    return m.replace_instruction(ins, op, args);
                }};
    }

    std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }

    const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
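
The header above uses the curiously recurring template pattern (CRTP): compile() calls derived().get_kernel_name(op) and derived().make_interpolated_string(op), so a concrete scatter compiler only supplies the pieces that differ. Below is a minimal standalone sketch of that dispatch with toy types standing in for the MIGraphX compiler machinery; fake_op and the class names are illustrative, not part of this commit.

// Toy CRTP customization-point example (not MIGraphX API).
#include <iostream>
#include <string>

struct fake_op
{
    std::string name;
};

template <typename Derived>
struct scatter_compiler_base
{
    // Default customization point; a derived class may shadow it.
    std::string get_kernel_name(const fake_op& op) const { return op.name + "_kernel"; }

    std::string compile(const fake_op& op) const
    {
        // Dispatches to Derived::get_kernel_name if defined, otherwise the default above.
        return derived().get_kernel_name(op);
    }

    const Derived& derived() const { return static_cast<const Derived&>(*this); }
};

// Uses the default kernel name.
struct scatter_elements_like_compiler : scatter_compiler_base<scatter_elements_like_compiler>
{
};

// Overrides the kernel name, mirroring what scatternd_compiler does in the diff below.
struct scatternd_like_compiler : scatter_compiler_base<scatternd_like_compiler>
{
    std::string get_kernel_name(const fake_op&) const { return "scatternd_kernel"; }
};

int main()
{
    std::cout << scatter_elements_like_compiler{}.compile({"scatter_none"}) << "\n"; // scatter_none_kernel
    std::cout << scatternd_like_compiler{}.compile({"scatternd_add"}) << "\n";       // scatternd_kernel
    return 0;
}

The scatternd_compiler change that follows overrides get_kernel_name in exactly this way.
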
@@ -21,11 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <migraphx/gpu/compiler.hpp>
-#include <migraphx/make_op.hpp>
-#include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/compile_hip_code_object.hpp>
-#include <migraphx/gpu/compile_hip.hpp>
+#include "scatter.hpp"
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
 )__migraphx__";
-struct scatternd_compiler : compiler<scatternd_compiler>
+struct scatternd_compiler : scatter_compiler<scatternd_compiler>
 {
     std::vector<std::string> names() const
     {
-        return {"scatternd_none", "scatternd_add", "scatternd_mul"};
+        return {
+            "scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
     }
-    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    std::string make_interpolated_string(const operation& op) const
     {
-        hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
-        options.inputs = inputs;
-        options.output = inputs.back();
-        options.kernel_name = "scatternd_kernel";
-        options.virtual_inputs = inputs;
-        auto reduction = "assign_" + v.get("reduction", std::string{"none"});
-        auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
-        return compile_hip_code_object(src, options);
+        const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
+        return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
     }
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
-    {
-        assert(starts_with(op.name(), "scatternd_"));
-        auto reduction = op.name().substr(10);
-        return insert(compile_op(
-            ctx,
-            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
-            {{"reduction", reduction}}));
-    }
-    compiler_replace insert(const operation& co) const
-    {
-        return {co, [](module& m, instruction_ref ins, const operation& op) {
-                    auto args = ins->inputs();
-                    args.back() =
-                        m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
-                    args.erase(args.begin());
-                    return m.replace_instruction(ins, op, args);
-                }};
-    }
+    std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
 };
 } // namespace gpu
...
@@ -22,8 +22,13 @@
 #ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
 #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
+#include <migraphx/kernels/type_traits.hpp>
 namespace migraphx {
-template <typename To, typename From>
+template <typename To,
+          typename From,
+          MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
 inline constexpr To bit_cast(From fr) noexcept
 {
     static_assert(sizeof(To) == sizeof(From));
...
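
The bitcast change above constrains bit_cast to trivially copyable types via MIGRAPHX_REQUIRES. A rough standalone equivalent in plain C++ is sketched below; it uses std::enable_if and the __builtin_bit_cast intrinsic (what std::bit_cast wraps in C++20) and is an illustration, not the MIGraphX header.

#include <cstdint>
#include <type_traits>

template <typename To,
          typename From,
          typename std::enable_if<std::is_trivially_copyable<To>::value and
                                      std::is_trivially_copyable<From>::value,
                                  int>::type = 0>
constexpr To bit_cast(From fr) noexcept
{
    static_assert(sizeof(To) == sizeof(From), "bit_cast requires equally sized types");
    return __builtin_bit_cast(To, fr); // std::bit_cast in C++20
}

int main()
{
    const float one = 1.0f;
    const auto bits = bit_cast<std::uint32_t>(one); // 0x3f800000 for IEEE-754 float
    return bits == 0x3f800000u ? 0 : 1;
}
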
@@ -365,15 +365,6 @@ struct float8
     inline __device__ constexpr float8& operator=(const float8& rhs) = default;
     inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default;
-    inline __device__ constexpr bool operator==(const float8& rhs) const
-    {
-        if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
-            return false;
-        else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
-            return true;
-        return false;
-    }
     inline __device__ constexpr bool operator<(const float8& rhs) const
     {
         const auto we = static_cast<float>(*this);
@@ -403,12 +394,21 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
 }
 // NOLINTNEXTLINE
-#define MIGRAPHX_FP8_FABS(T)                \
-    inline constexpr __device__ T fabs(T v) \
-    {                                       \
-        /*NOLINTNEXTLINE*/                  \
-        v.data = v.data & 0x7f;             \
-        return v;                           \
-    }
+#define MIGRAPHX_FP8_OTHER_OPS(T)                                            \
+    inline constexpr __device__ T fabs(T v)                                  \
+    {                                                                        \
+        /*NOLINTNEXTLINE*/                                                   \
+        v.data = v.data & 0x7f;                                              \
+        return v;                                                            \
+    }                                                                        \
+    inline __device__ constexpr bool operator==(const T& lhs, const T& rhs)  \
+    {                                                                        \
+        if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf())     \
+            return false;                                                    \
+        else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
+            return true;                                                     \
+        return false;                                                        \
+    }
 // NOLINTNEXTLINE
@@ -417,11 +417,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
     MIGRAPHX_FP8_BINARY_OP(-, T, T)     \
     MIGRAPHX_FP8_BINARY_OP(/, T, T)     \
     MIGRAPHX_FP8_BINARY_OP(+, T, T)     \
-    MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
-    MIGRAPHX_FP8_FABS(T)
+    MIGRAPHX_FP8_OTHER_OPS(T)
 MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
 MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
@@ -447,7 +446,7 @@ class numeric_limits<fp8e4m3fnuz>
     {
         return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
     }
-    // this is min value that is not DeNorm. DeNorm min is 0x01
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
     static constexpr __device__ fp8e4m3fnuz min()
     {
         return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
@@ -475,7 +474,7 @@ class numeric_limits<fp8e4m3fn>
     }
     static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
-    // this is min value that is not DeNorm. DeNorm min is 0x01
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
     static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
     static constexpr __device__ fp8e4m3fn lowest()
@@ -503,8 +502,10 @@ class numeric_limits<fp8e5m2fnuz>
     {
         return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
     }
-    // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
-    // this distinction. For the floating points we would end up using lowest most of the times.
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
+    // want to make this distinction. For the floating points we would end up using lowest most of
+    // the times.
     static constexpr __device__ fp8e5m2fnuz min()
     {
         return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
@@ -529,8 +530,7 @@ class numeric_limits<fp8e5m2>
     }
     static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
-    // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
-    // this distinction. For the floating points we would end up using lowest most of the times.
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
     static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
     static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
@@ -540,23 +540,27 @@ class numeric_limits<fp8e5m2>
 } // namespace fp8
-// NOLINTNEXTLINE
-#define MIGRAPHX_FP8_MIN_MAX(T)                  \
-    template <>                                  \
-    constexpr T numeric_max<T, void>()           \
-    {                                            \
-        return fp8::numeric_limits<T>::max();    \
-    }                                            \
-    template <>                                  \
-    constexpr T numeric_lowest<T>()              \
-    {                                            \
-        return fp8::numeric_limits<T>::lowest(); \
-    }
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
+template <class T,
+          MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
+                            is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
+constexpr T numeric_max(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
+{
+    // unused parameter is added to make this numeric_max a different overload definition
+    // compared to numeric_max defined in type_traits.hpp
+    (void)(unused);
+    return fp8::numeric_limits<T>::max();
+}
+template <class T,
+          MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
+                            is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
+constexpr T numeric_lowest(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
+{
+    // unused parameter is added to make this numeric_lowest a different overload definition
+    // compared to numeric_lowest defined in type_traits.hpp
+    (void)(unused);
+    return fp8::numeric_limits<T>::lowest();
+}
 } // namespace migraphx
 // =================================================================================================
 #if defined(__clang__)
...
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
     auto indices_shape_lens = indices_shape.lens;
     auto data_shape_lens = data_shape.lens;
     auto num_slice_dims = indices_shape_lens.back();
-    std::size_t num_slices =
+    size_t num_slices =
         accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
-    std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
+    size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                    data_shape_lens.end(),
                                    1,
                                    op::product{});
-    const std::size_t num_batches =
+    const size_t num_batches =
         accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
-    const std::size_t data_batch_stride =
+    const size_t data_batch_stride =
         accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
     const auto num_slices_per_batch = num_slices / num_batches;
     ind.global_stride(output_shape.elements(), [&](auto i) {
         const auto* indices_ptr = indices_t.data();
-        const std::size_t j = i / slice_size;
-        const std::size_t batch_idx = j / num_slices_per_batch;
+        const size_t j = i / slice_size;
+        const size_t batch_idx = j / num_slices_per_batch;
         auto* slice_indices = indices_ptr + (j * num_slice_dims);
-        std::size_t relative_slice_offset = 0;
-        for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
+        size_t relative_slice_offset = 0;
+        for(size_t idx = 0; idx < num_slice_dims; ++idx)
         {
             int64_t index = slice_indices[idx];
-            const std::size_t input_dim_idx = batch_dims + idx;
+            const size_t input_dim_idx = batch_dims + idx;
             const auto input_dim = data_shape_lens[input_dim_idx];
             MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
                             index < static_cast<int64_t>(input_dim));
             if(index < 0)
                 index += input_dim;
-            std::size_t size_from_slice_dims =
+            size_t size_from_slice_dims =
                 accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                            data_shape_lens.begin() + batch_dims + num_slice_dims,
                            slice_size,
...
@@ -54,12 +54,12 @@ __device__ void generic_binary_layernorm(
     using value_type = typename Input1::type;
     using vec_value_type = vec_type<value_type>;
     constexpr auto relements = r.template elements<Input1>();
-    constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
+    constexpr auto relements_r = vec_value_type{1.0 / relements};
     auto relements_rsqrt = sqrt(relements_r);
     auto means = r.reduce(op::sum{},
-                          make_array<vec_value_type>(static_cast<vec_value_type>(0),
-                                                     static_cast<vec_value_type>(0)),
+                          make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
                           [&](auto x) {
                               auto x_out = x * relements_r;
                               // dividing x by sqrt(relements) before squaring allows computing
@@ -71,7 +71,7 @@ __device__ void generic_binary_layernorm(
     auto mean_x = means[0];
     auto mean_x2 = means[1];
     auto variance = mean_x2 - (mean_x * mean_x);
-    value_type eps_val = static_cast<value_type>(eps);
+    value_type eps_val = implicit_conversion(eps);
     r.inner([&](auto& y, auto x, auto... xs) {
         auto m = x - mean_x;
...
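
The layernorm kernel above computes E[x] and E[x^2] in a single reduction: it scales x by 1/relements for the mean and by 1/sqrt(relements) before squaring so the squared term accumulates E[x^2] directly, and the variance then falls out as E[x^2] - E[x]^2. A scalar sketch of the same arithmetic, in plain host C++ rather than the vectorized reduce:

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
    const float n       = static_cast<float>(x.size());
    const float n_r     = 1.0f / n;       // plays the role of relements_r
    const float n_rsqrt = std::sqrt(n_r); // plays the role of relements_rsqrt

    float mean_x = 0.0f, mean_x2 = 0.0f;
    for(float v : x)
    {
        mean_x += v * n_r;                // accumulates E[x]
        const float scaled = v * n_rsqrt; // divide by sqrt(n) before squaring
        mean_x2 += scaled * scaled;       // accumulates E[x^2], keeping intermediates small
    }
    const float variance = mean_x2 - mean_x * mean_x; // 7.5 - 6.25 = 1.25 for this input
    std::printf("mean=%f variance=%f\n", mean_x, variance);
    return 0;
}
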
@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
 template <class T, class U>
 constexpr auto convert(U v)
 {
-    return vec_transform(v)([](auto x) { return static_cast<T>(x); });
+    return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
 }
 } // namespace migraphx
...
@@ -28,6 +28,7 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/algorithm.hpp>
 #include <migraphx/kernels/ranges.hpp>
+#include <migraphx/kernels/vec.hpp>
 namespace migraphx {
@@ -54,9 +55,9 @@ __device__ void pad(const index& idx,
         if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
                return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
            }))
-            output[multi] = otype(pad_val);
+            output[multi] = implicit_conversion(pad_val);
         else
-            output[multi] = otype(input[input_idx]);
+            output[multi] = implicit_conversion(input[input_idx]);
     });
 }
...
@@ -62,7 +62,7 @@ struct avg_pool
     template <class T>
     MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
    {
-        return (y == 0) ? static_cast<T>(0.0) : static_cast<T>(x / y);
+        return (y == 0) ? T{0.0} : T{x / y};
     }
 };
@@ -77,7 +77,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
     {
         if(xy[ii] < -1.0f or xy[ii] > dims[ii])
         {
-            return static_cast<ret_type>(0);
+            return implicit_conversion(0);
         }
         xy[ii] = migraphx::max(xy[ii], 0.0f);
@@ -93,18 +93,17 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
                              high[0] * dims[1] + low[1],
                              high[0] * dims[1] + high[1]};
     float ly = xy[0] - low[0];
     float lx = xy[1] - low[1];
     float hy = 1.0f - ly;
     float hx = 1.0f - lx;
-    array<ret_type, 4> ws = {static_cast<ret_type>(hy * hx),
-                             static_cast<ret_type>(hy * lx),
-                             static_cast<ret_type>(ly * hx),
-                             static_cast<ret_type>(ly * lx)};
+    // do calculations in floating point and convert final result to required type
+    array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
     auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
     auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
-    return pooling(v01, v23);
+    return implicit_conversion(pooling(v01, v23));
 }
 template <class Iterator, class Op>
@@ -153,7 +152,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
     const auto x = x_t.begin();
     const auto rois = rois_t.begin();
     const auto ind = ind_t.begin();
-    using ytype = typename W::type;
     // input shape
     auto x_lens = x_t.get_shape().lens;
     auto channel_num = x_lens[1];
...
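
The roialign change above computes the four bilinear weights hy*hx, hy*lx, ly*hx and ly*lx in plain float and only converts the pooled result to the output type at the end. A standalone sketch of that weighting on a 2x2 patch; the code here is illustrative, not the kernel itself:

#include <array>
#include <cmath>
#include <cstdio>

float bilinear(const std::array<std::array<float, 2>, 2>& corners, float y, float x)
{
    const float low_y = std::floor(y), low_x = std::floor(x);
    const float ly = y - low_y, lx = x - low_x; // fractional offsets inside the cell
    const float hy = 1.0f - ly, hx = 1.0f - lx;
    const std::array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; // weights sum to 1
    return corners[0][0] * ws[0] + corners[0][1] * ws[1] + corners[1][0] * ws[2] +
           corners[1][1] * ws[3];
}

int main()
{
    // Corner values of a 2x2 patch; the midpoint (0.5, 0.5) averages all four corners.
    const std::array<std::array<float, 2>, 2> patch = {{{0.0f, 2.0f}, {4.0f, 6.0f}}};
    std::printf("%f\n", bilinear(patch, 0.5f, 0.5f)); // prints 3.000000
    return 0;
}
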
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
struct assign_none
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        x = y;
    }
};

struct assign_add
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        atomicAdd(&x, y);
    }
};

struct assign_mul
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        T old = x;
        T assumed;
        do
        {
            assumed = old;
            old     = atomicCAS(&x, assumed, assumed * y);
        } while(assumed != old);
    }
};

struct assign_max
{
    template <typename T, typename U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        atomicMax(&x, y);
    }
};

struct assign_min
{
    template <typename T, typename U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
        atomicMin(&x, y);
    }
};
} // namespace migraphx
#endif
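
assign_mul above emulates an atomic multiply with an atomicCAS retry loop, since HIP/CUDA provide atomicAdd, atomicMax and atomicMin but no atomicMul: read the current value, compute old * y, and retry until no other thread changed the slot in between. A host-side analog of the same read-modify-write retry pattern using std::atomic (plain C++, not device code):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

void atomic_mul(std::atomic<int>& x, int y)
{
    int expected = x.load();
    // compare_exchange_weak updates 'expected' with the current value on failure,
    // so the loop naturally retries with fresh data, just like the atomicCAS loop.
    while(not x.compare_exchange_weak(expected, expected * y))
    {
    }
}

int main()
{
    std::atomic<int> value{1};
    std::vector<std::thread> workers;
    for(int i = 0; i < 4; ++i)
        workers.emplace_back([&] { atomic_mul(value, 2); });
    for(auto& w : workers)
        w.join();
    std::printf("%d\n", value.load()); // 16, regardless of interleaving
    return 0;
}
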
@@ -26,36 +26,10 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/algorithm.hpp>
+#include <migraphx/kernels/scatter_reduction_modes.hpp>
 namespace migraphx {
-struct assign_none
-{
-    template <class T, class U>
-    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
-    {
-        x = y;
-    }
-};
-struct assign_add
-{
-    template <class T, class U>
-    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
-    {
-        x += y;
-    }
-};
-struct assign_mul
-{
-    template <class T, class U>
-    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
-    {
-        x *= y;
-    }
-};
 template <class T, class U, class V, class F>
 __device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
 {
...
@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
         auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
         auto batch_sum =
             r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
-        r.inner([&](auto& y, auto x) { y = static_cast<otype>(x / batch_sum); })(output, exp_in);
+        r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
     });
 }
...
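
The softmax kernel above follows the standard numerically stable recipe: subtract the row maximum before exponentiating, accumulate the sum in float, then normalize. A scalar sketch of the same steps as standalone host code, not the reduce-based kernel:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> softmax(const std::vector<float>& in)
{
    const float c = *std::max_element(in.begin(), in.end()); // row max, so exp never overflows
    std::vector<float> out(in.size());
    float batch_sum = 0.0f;
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        out[i] = std::exp(in[i] - c); // exp_in
        batch_sum += out[i];
    }
    for(auto& y : out)
        y /= batch_sum;
    return out;
}

int main()
{
    for(float y : softmax({1.0f, 2.0f, 3.0f}))
        std::printf("%f ", y); // ~0.090 0.245 0.665
    std::printf("\n");
    return 0;
}
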
@@ -251,7 +251,7 @@ constexpr T numeric_max()
 }
 template <class T>
-constexpr T numeric_lowest()
+constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
 {
     if constexpr(is_integral<T>{})
     {
...
@@ -207,7 +207,7 @@ struct implicit_conversion_op
     template <class U>
     constexpr operator U() const
     {
-        return x;
+        return static_cast<U>(x);
     }
 };
...
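
implicit_conversion_op now applies static_cast<U>(x) inside its templated conversion operator, so the wrapped value converts to whatever type the assignment target declares; this is what lets the kernels above write implicit_conversion(...) without spelling out the destination type. A minimal standalone sketch of the idea (toy code, not the MIGraphX header verbatim):

#include <cstdint>
#include <cstdio>

template <class T>
struct implicit_conversion_op
{
    T x;

    template <class U>
    constexpr operator U() const
    {
        return static_cast<U>(x); // conversion deferred until the destination type is known
    }
};

template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
    return {x};
}

int main()
{
    double pad_val = 3.7;
    float f        = implicit_conversion(pad_val); // converts to float
    std::int32_t i = implicit_conversion(pad_val); // converts to int32_t (truncates)
    std::printf("%f %d\n", f, i);                  // 3.700000 3
    return 0;
}
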
@@ -73,6 +73,7 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
@@ -796,7 +797,9 @@ struct mlir_program
         if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
             tuning_mode = RocmlirTuningParamSetKindExhaustive;
         mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
-        for(auto i : range(mlirRockTuningGetNumParams(params.get())))
+        const auto limit =
+            value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
+        for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
         {
             mlir_tuning_param param{mlirRockTuningParamCreate()};
             if(not mlirRockTuningParamGet(params.get(), i, param.get()))
@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
+    const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
+    static std::mutex mutex;
+    if(trace)
+    {
+        const std::lock_guard<std::mutex> lock(mutex);
+        auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
+        std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
+    }
     return mp.get_tuning_config(exhaustive);
 }
...
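
The MLIR change above introduces MIGRAPHX_MLIR_TUNE_LIMIT and caps the tuning loop at min(limit, number of params). A standalone sketch of that pattern, reading an optional numeric limit from the environment; the helper below is illustrative, since MIGraphX uses its own value_of() and range() utilities.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <string>

std::size_t limit_from_env(const char* name)
{
    const char* s = std::getenv(name);
    // Unset variable means "no limit".
    return s == nullptr ? std::numeric_limits<std::size_t>::max() : std::stoull(s);
}

int main()
{
    const std::size_t num_params = 1000; // stand-in for the size of the tuning space
    const std::size_t limit      = limit_from_env("MIGRAPHX_MLIR_TUNE_LIMIT");
    std::size_t visited          = 0;
    for(std::size_t i = 0; i < std::min(limit, num_params); ++i)
    {
        ++visited; // evaluate tuning configuration i here
    }
    std::printf("visited %zu configs\n", visited);
    return 0;
}
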
@@ -28,7 +28,10 @@
 #include <migraphx/register_op.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/dead_code_elimination.hpp>
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
 #include <migraphx/gpu/ck.hpp>
+#endif
+#include <migraphx/gpu/fuse_mlir.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -128,26 +131,49 @@ struct pre_gemm_softmax_gemm : gemm_softmax_gemm
 };
 MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
-MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
+auto is_ck_gemm()
 {
-    if(ins->name() != "dot")
-        return false;
-    if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
-        return false;
-    return true;
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
+        if(not enabled(MIGRAPHX_ENABLE_CK{}))
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+            return false;
+        return true;
+#else
+        (void)ins;
+        return false;
+#endif
+    });
+}
+auto is_mlir_gemm()
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(not mlir_attention_enabled())
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
+            return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
+        });
+    });
 }
 struct find_gemm_softmax_gemm
 {
     auto matcher() const
     {
-        auto gemm1 =
-            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto gemm1 = match::skip(match::name("contiguous"))(
+            match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
         auto mul = match::name("mul")(
             match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
         auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+        return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
+            match::arg(0)(softmax));
     }
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -182,8 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
     match::find_matches(mpm.get_module(), find_layernorm{});
     mpm.run_pass(dead_code_elimination{});
     match::find_matches(mpm.get_module(), find_add_layernorm{});
-    if(enabled(MIGRAPHX_ENABLE_CK{}))
-        match::find_matches(mpm, find_gemm_softmax_gemm{});
+    match::find_matches(mpm, find_gemm_softmax_gemm{});
 }
 } // namespace gpu
...
@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
 find_path(BLAZE_INCLUDE blaze/Blaze.h)
 rocm_clang_tidy_check(migraphx_ref)
+target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
 target_link_libraries(migraphx_ref PUBLIC migraphx)
-target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
+target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
 target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
 migraphx_generate_export_header(migraphx_ref)
...
@@ -38,7 +38,11 @@ protobuf_generate_cpp(
 )
 add_library(tf-proto STATIC ${PROTO_SRCS})
 target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
-target_compile_options(tf-proto PRIVATE -w)
+if(MSVC)
+  target_compile_options(tf-proto PRIVATE /w)
+else()
+  target_compile_options(tf-proto PRIVATE -w)
+endif()
 target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
 set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
 rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_tf)
-target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL")
+target_link_libraries(migraphx_tf PRIVATE tf-proto)
+if(NOT WIN32)
+  target_link_libraries(migraphx_tf PRIVATE "-Wl,--exclude-libs,ALL")
+endif()
 target_link_libraries(migraphx_tf PUBLIC migraphx)
 rocm_install_targets(
...
@@ -31,8 +31,18 @@
 #include <sstream>
 #include <iostream>
 #include <string>
+#include <sys/types.h>
+
+#ifdef _WIN32
+// cppcheck-suppress definePrefix
+#define WIN32_LEAN_AND_MEAN
+#include <Windows.h>
+#undef getpid
+// cppcheck-suppress [definePrefix, defineUpperCase]
+#define getpid _getpid
+#else
 #include <unistd.h>
-#include <sys/types.h>
+#endif
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
         auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
         mm->add_return({tanh});
     }
-    migraphx::program p2(p1);
-    // This pass should do nothing as int32_t tanh isn't supported.
+    // This pass should not fuse as int32_t tanh isn't supported.
     run_pass(p1);
-    EXPECT(p1 == p2);
+    auto* mm = p1.get_main_module();
+    bool has_pointwise =
+        std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
+    EXPECT(has_pointwise);
 }

 int main(int argc, const char* argv[])
...