Unverified Commit c4cee345 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Merge branch 'develop' into rocblas_fp8

parents c40a39c3 eafd55de
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
const auto inputs =
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
hip_compile_options options;
options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = derived().get_kernel_name(op);
options.virtual_inputs = inputs;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options.params += "-Wno-float-equal";
const auto src = derived().make_interpolated_string(op);
return prepend_copy_data_to_output(compile_hip_code_object(src, options));
}
compiler_replace prepend_copy_data_to_output(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -21,11 +21,7 @@ ...@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/compiler.hpp> #include "scatter.hpp"
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void* ...@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__"; )__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler> struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const
{ {
return {"scatternd_none", "scatternd_add", "scatternd_mul"}; return {
"scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
} }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const std::string make_interpolated_string(const operation& op) const
{ {
hip_compile_options options; const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements())); return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -22,8 +22,13 @@ ...@@ -22,8 +22,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx { namespace migraphx {
template <typename To, typename From> template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept inline constexpr To bit_cast(From fr) noexcept
{ {
static_assert(sizeof(To) == sizeof(From)); static_assert(sizeof(To) == sizeof(From));
......
...@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, ...@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens; auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens; auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back(); auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{}); accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(), data_shape_lens.end(),
1, 1,
op::product{}); op::product{});
const std::size_t num_batches = const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{}); accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride = const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{}); accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches; const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) { ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data(); const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size; const size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch; const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims); auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0; size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx) for(size_t idx = 0; idx < num_slice_dims; ++idx)
{ {
int64_t index = slice_indices[idx]; int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx; const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx]; const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim)); index < static_cast<int64_t>(input_dim));
if(index < 0) if(index < 0)
index += input_dim; index += input_dim;
std::size_t size_from_slice_dims = size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1, accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims, data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size, slice_size,
......
...@@ -54,12 +54,12 @@ __device__ void generic_binary_layernorm( ...@@ -54,12 +54,12 @@ __device__ void generic_binary_layernorm(
using value_type = typename Input1::type; using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>; using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r); auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, auto means = r.reduce(op::sum{},
make_array<vec_value_type>(static_cast<vec_value_type>(0), make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
static_cast<vec_value_type>(0)),
[&](auto x) { [&](auto x) {
auto x_out = x * relements_r; auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing // dividing x by sqrt(relements) before squaring allows computing
...@@ -71,7 +71,7 @@ __device__ void generic_binary_layernorm( ...@@ -71,7 +71,7 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = static_cast<value_type>(eps); value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x; auto m = x - mean_x;
......
...@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where) ...@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U> template <class T, class U>
constexpr auto convert(U v) constexpr auto convert(U v)
{ {
return vec_transform(v)([](auto x) { return static_cast<T>(x); }); return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = static_cast<otype>(x / batch_sum); })(output, exp_in); r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
}); });
} }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment