Commit ffcb68b4 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents ee88607c 7604ecf5
...@@ -936,7 +936,7 @@ void program::perf_report(std::ostream& os, ...@@ -936,7 +936,7 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Batch size: " << batch << std::endl; os << "Batch size: " << batch << std::endl;
os << "Rate: " << rate * batch << "/sec" << std::endl; os << "Rate: " << rate * batch << "inferences/sec" << std::endl;
os << "Total time: " << total_time << "ms" << std::endl; os << "Total time: " << total_time << "ms" << std::endl;
os << "Total instructions time: " << total_instruction_time << "ms" << std::endl; os << "Total instructions time: " << total_instruction_time << "ms" << std::endl;
os << "Overhead time: " << overhead_time << "ms" os << "Overhead time: " << overhead_time << "ms"
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
void apply_quantizelinear(module& m, instruction_ref ins) void apply_quantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "quantizelinear"); assert(ins->name() == "quantizelinear");
...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}}); instruction_ref min_arg;
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}}); instruction_ref max_arg;
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
else
{
min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
}
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg}); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmax(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmin(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
...@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params( ...@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
global = compute_global(local); global = compute_global(local);
} }
static bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
std::function<std::size_t(std::size_t local)> std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over) compute_global_for(context& ctx, std::size_t n, std::size_t over)
{ {
...@@ -146,11 +152,12 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -146,11 +152,12 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) { return [n, over, max_global](std::size_t local) {
// hip require global workitems multiple of local workitems. It may degrade performance. std::size_t num_elements = n;
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available. if(not hip_accept_non_uniform_wg())
// https://reviews.llvm.org/D155213 {
std::size_t num_elements = ((n + local - 1) / local) * local; num_elements = (1 + (n - 1) / local) * local;
std::size_t groups = (num_elements + local - 1) / local; }
std::size_t groups = 1 + (num_elements - 1) / local;
std::size_t max_blocks = max_global / local; std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return std::min(nglobal, num_elements); return std::min(nglobal, num_elements);
...@@ -183,6 +190,11 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -183,6 +190,11 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs); generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.emplace_back("args.hpp", args_hpp); srcs.emplace_back("args.hpp", args_hpp);
if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
options.params += " -fno-offload-uniform-block";
else
assert(options.global % options.local == 0);
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
......
...@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);
struct precompile_op struct precompile_op
{ {
...@@ -179,13 +180,28 @@ struct compile_plan ...@@ -179,13 +180,28 @@ struct compile_plan
MIGRAPHX_THROW("Multiple kernels without config"); MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl; << std::endl;
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times; std::vector<double> times;
times.reserve(results.size()); times.reserve(results.size());
std::transform( std::transform(results.begin(),
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) { results.end(),
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value()) if(not cr.has_value())
{
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20); }
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << t << "ms" << std::endl;
return t;
}); });
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmax(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmax_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmax_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmax_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmin(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmin_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmin_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmin_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -92,6 +93,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -92,6 +93,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto m = a.lens()[a.lens().size() - 2]; auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back(); auto n = b.lens().back();
auto k = a.lens().back(); auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
// Integer gemms must be divisible by 4 in ck // Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type())) if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{ {
...@@ -102,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -102,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if(k % 4 != 0) if(k % 4 != 0)
return false; return false;
} }
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy auto device_name = trim(split_string(get_device_name(), ':').front());
// to avoid poor-performing GEMM kernels from CK if(device_name == "gfx940")
// To-do: Investigate a more precise strategy {
if(ins->get_shape().type() == shape::half_type)
{
if(batch_size >= 64)
return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
return true;
}
return true;
}
return k <= 2048; return k <= 2048;
} }
...@@ -140,6 +151,10 @@ struct find_ck_gemm_pointwise ...@@ -140,6 +151,10 @@ struct find_ck_gemm_pointwise
return not input->inputs().empty() and input->inputs().front()->name() == "capture"; return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
......
...@@ -199,9 +199,9 @@ struct miopen_convolution ...@@ -199,9 +199,9 @@ struct miopen_convolution
// MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6 // MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
preallocate = true; preallocate = true;
#endif #endif
auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0]; auto x = preallocate ? to_gpu(generate_argument(x_shape)) : argument{inputs[0]};
auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1]; auto w = preallocate ? to_gpu(generate_argument(w_shape)) : argument{inputs[1]};
auto y = preallocate ? allocate_gpu(output_shape) : inputs[2]; auto y = preallocate ? allocate_gpu(output_shape) : argument{inputs[2]};
auto workspace = auto workspace =
preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape); preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i) ...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return {v, i}; return {v, i};
} }
struct argmax_op struct argmax_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -73,7 +73,25 @@ struct argmax_op ...@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
}; };
struct argmin_op struct argmax_op_last_index
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index > y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -91,6 +109,24 @@ struct argmin_op ...@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
}; };
struct argmin_op_last_index
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
return x;
else if(x.val > y.val)
return y;
else
{
return (x.index > y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op> template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,7 +36,8 @@ namespace device { ...@@ -36,7 +36,8 @@ namespace device {
void MIGRAPHX_DEVICE_EXPORT argmax(hipStream_t stream, void MIGRAPHX_DEVICE_EXPORT argmax(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
int64_t axis); int64_t axis,
bool select_last_index);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,7 +36,8 @@ namespace device { ...@@ -36,7 +36,8 @@ namespace device {
void MIGRAPHX_DEVICE_EXPORT argmin(hipStream_t stream, void MIGRAPHX_DEVICE_EXPORT argmin(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
int64_t axis); int64_t axis,
bool select_last_index);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -31,6 +31,14 @@ ...@@ -31,6 +31,14 @@
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
extern "C" __device__ size_t __ockl_get_enqueued_local_size(uint); // NOLINT
extern "C" __device__ size_t __ockl_get_local_size(uint); // NOLINT
#pragma clang diagnostic pop
#endif
namespace migraphx { namespace migraphx {
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL) #if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
...@@ -49,39 +57,33 @@ inline __device__ __attribute__((const)) index_int compute_global_size() ...@@ -49,39 +57,33 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
#endif #endif
} }
// We cant just use blockDim.x to get the local size since its broken on hip #ifdef MIGRAPHX_NGROUP
// when global is not divisible by local size. In this case, we calulate the // If global is divisible by local then local can be a const
// size for the last group. #if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
inline __device__ __attribute__((const)) index_int compute_local_size() inline __device__ __attribute__((const)) index_int compute_local_size()
{ {
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_HAS_CONST_LOCAL
const auto nlocal = MIGRAPHX_NLOCAL; return MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#else #else
const auto ngroup = gridDim.x; // NOLINT // Returns block size. For the non-uniform block it returns the size of the non-uniform block.
return __ockl_get_local_size(0); // NOLINT
#endif #endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
} }
#ifdef MIGRAPHX_NGROUP inline __device__ __attribute__((const)) index_int compute_max_local_size()
// If global is divisible by local then local can be a const {
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1) #ifdef MIGRAPHX_LOCAL
#define MIGRAPHX_HAS_CONST_LOCAL 1 return MIGRAPHX_NLOCAL;
#endif #else
// Returns the block size. When workgrop has non-uniform block, this returns size of the uniform
// block.
return __ockl_get_enqueued_local_size(0); // NOLINT
#endif #endif
}
struct index struct index
{ {
...@@ -126,8 +128,8 @@ struct index ...@@ -126,8 +128,8 @@ struct index
#else #else
__device__ index_int max_nlocal() const __device__ index_int max_nlocal() const
{ {
MIGRAPHX_ASSERT(blockDim.x > 0); MIGRAPHX_ASSERT(compute_max_local_size() > 0);
return blockDim.x; return compute_max_local_size();
} }
#endif #endif
...@@ -249,7 +251,8 @@ struct index ...@@ -249,7 +251,8 @@ struct index
#endif #endif
inline __device__ __attribute__((const)) index make_index() inline __device__ __attribute__((const)) index make_index()
{ {
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT return index{
blockIdx.x * compute_max_local_size() + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -55,7 +55,7 @@ struct allocate ...@@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
......
...@@ -60,7 +60,7 @@ struct concat ...@@ -60,7 +60,7 @@ struct concat
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
...@@ -104,7 +104,7 @@ struct allocate ...@@ -104,7 +104,7 @@ struct allocate
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
......
...@@ -55,7 +55,7 @@ struct allocate ...@@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
......
...@@ -57,7 +57,7 @@ struct normalize_test_op ...@@ -57,7 +57,7 @@ struct normalize_test_op
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
......
6d7bc2a097a1a08541cd0d4628831c79ab8092d5 635d3faa3b3908d2806d009dc6872152cfcfcdda
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment