Unverified Commit c4cee345 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Merge branch 'develop' into rocblas_fp8

parents c40a39c3 eafd55de
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
const auto inputs =
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
hip_compile_options options;
options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = derived().get_kernel_name(op);
options.virtual_inputs = inputs;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options.params += "-Wno-float-equal";
const auto src = derived().make_interpolated_string(op);
return prepend_copy_data_to_output(compile_hip_code_object(src, options));
}
compiler_replace prepend_copy_data_to_output(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
return {
"scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
std::string make_interpolated_string(const operation& op) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
};
} // namespace gpu
......
......@@ -22,8 +22,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <typename To, typename From>
template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
......
......@@ -365,15 +365,6 @@ struct float8
inline __device__ constexpr float8& operator=(const float8& rhs) = default;
inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default;
inline __device__ constexpr bool operator==(const float8& rhs) const
{
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true;
return false;
}
inline __device__ constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
......@@ -403,12 +394,21 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/ \
v.data = v.data & 0x7f; \
return v; \
#define MIGRAPHX_FP8_OTHER_OPS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/ \
v.data = v.data & 0x7f; \
return v; \
} \
inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
{ \
if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf()) \
return false; \
else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
return true; \
return false; \
}
// NOLINTNEXTLINE
......@@ -417,11 +417,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_FABS(T)
MIGRAPHX_FP8_OTHER_OPS(T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
......@@ -447,7 +446,7 @@ class numeric_limits<fp8e4m3fnuz>
{
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fnuz min()
{
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
......@@ -475,7 +474,7 @@ class numeric_limits<fp8e4m3fn>
}
static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr __device__ fp8e4m3fn lowest()
......@@ -503,8 +502,10 @@ class numeric_limits<fp8e5m2fnuz>
{
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static constexpr __device__ fp8e5m2fnuz min()
{
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
......@@ -529,8 +530,7 @@ class numeric_limits<fp8e5m2>
}
static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
......@@ -540,23 +540,27 @@ class numeric_limits<fp8e5m2>
} // namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
template <class T,
MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_max(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
{
// unused parameter is added to make this numeric_max different overload definition
// compared to numeric_max defined in type_traits.hpp
(void)(unused);
return fp8::numeric_limits<T>::max();
}
template <class T,
MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_lowest(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
{
// unused parameter is added to make this numeric_lowest different overload definition
// compared to numeric_lowest defined in type_traits.hpp
(void)(unused);
return fp8::numeric_limits<T>::lowest();
}
} // namespace migraphx
// =================================================================================================
#if defined(__clang__)
......
......@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices =
size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
op::product{});
const std::size_t num_batches =
size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
op::product{});
const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride =
const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
const size_t j = i / slice_size;
const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
size_t relative_slice_offset = 0;
for(size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
......
......@@ -54,12 +54,12 @@ __device__ void generic_binary_layernorm(
using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{},
make_array<vec_value_type>(static_cast<vec_value_type>(0),
static_cast<vec_value_type>(0)),
make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
[&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
......@@ -71,7 +71,7 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = static_cast<value_type>(eps);
value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
......
......@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) { return static_cast<T>(x); });
return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
}
} // namespace migraphx
......
......@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
......@@ -54,9 +55,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
}))
output[multi] = otype(pad_val);
output[multi] = implicit_conversion(pad_val);
else
output[multi] = otype(input[input_idx]);
output[multi] = implicit_conversion(input[input_idx]);
});
}
......
......@@ -62,7 +62,7 @@ struct avg_pool
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{
return (y == 0) ? static_cast<T>(0.0) : static_cast<T>(x / y);
return (y == 0) ? T{0.0} : T{x / y};
}
};
......@@ -77,7 +77,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{
if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{
return static_cast<ret_type>(0);
return implicit_conversion(0);
}
xy[ii] = migraphx::max(xy[ii], 0.0f);
......@@ -93,18 +93,17 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<ret_type, 4> ws = {static_cast<ret_type>(hy * hx),
static_cast<ret_type>(hy * lx),
static_cast<ret_type>(ly * hx),
static_cast<ret_type>(ly * lx)};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
// do calculations in floating point and convert final result to required type
array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23);
return implicit_conversion(pooling(v01, v23));
}
template <class Iterator, class Op>
......@@ -153,7 +152,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto x = x_t.begin();
const auto rois = rois_t.begin();
const auto ind = ind_t.begin();
using ytype = typename W::type;
// input shape
auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1];
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicAdd(&x, y);
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
T old = x;
T assumed;
do
{
assumed = old;
old = atomicCAS(&x, assumed, assumed * y);
} while(assumed != old);
}
};
struct assign_max
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMax(&x, y);
}
};
struct assign_min
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMin(&x, y);
}
};
} // namespace migraphx
#endif
......@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
......
......@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = static_cast<otype>(x / batch_sum); })(output, exp_in);
r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
});
}
......
......@@ -251,7 +251,7 @@ constexpr T numeric_max()
}
template <class T>
constexpr T numeric_lowest()
constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
{
if constexpr(is_integral<T>{})
{
......
......@@ -207,7 +207,7 @@ struct implicit_conversion_op
template <class U>
constexpr operator U() const
{
return x;
return static_cast<U>(x);
}
};
......
......@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
......@@ -796,7 +797,9 @@ struct mlir_program
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get())))
const auto limit =
value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
{
mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get()))
......@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
return mp.get_tuning_config(exhaustive);
}
......
......@@ -28,7 +28,10 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -128,26 +131,49 @@ struct pre_gemm_softmax_gemm : gemm_softmax_gemm
};
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto is_ck_gemm()
{
if(ins->name() != "dot")
return match::make_basic_pred_matcher([=](instruction_ref ins) {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if(not enabled(MIGRAPHX_ENABLE_CK{}))
return false;
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
#else
(void)ins;
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
#endif
});
}
auto is_mlir_gemm()
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(not mlir_attention_enabled())
return false;
if(ins->name() != "dot")
return false;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
});
});
}
struct find_gemm_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto gemm1 = match::skip(match::name("contiguous"))(
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......@@ -182,8 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
match::find_matches(mpm, find_gemm_softmax_gemm{});
}
} // namespace gpu
......
......@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref)
......
......@@ -38,7 +38,11 @@ protobuf_generate_cpp(
)
add_library(tf-proto STATIC ${PROTO_SRCS})
target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(tf-proto PRIVATE -w)
if(MSVC)
target_compile_options(tf-proto PRIVATE /w)
else()
target_compile_options(tf-proto PRIVATE -w)
endif()
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
......@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_tf)
target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_tf PRIVATE tf-proto)
if(NOT WIN32)
target_link_libraries(migraphx_tf PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_tf PUBLIC migraphx)
rocm_install_targets(
......
......@@ -31,8 +31,18 @@
#include <sstream>
#include <iostream>
#include <string>
#include <sys/types.h>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
// This pass should not fuse as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
auto* mm = p1.get_main_module();
bool has_pointwise =
std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
EXPECT(has_pointwise);
}
int main(int argc, const char* argv[])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment