Commit 351fde4d authored by Paul
Browse files

Handle non-const local

parent c78ce73d
......@@ -91,7 +91,7 @@ __device__ auto& array2vec(T& x)
template <class T, class... Ts>
constexpr auto array_for_each(T& x, Ts&... xs)
{
MIGRAPHX_ASSERT((x.size() == xs.size() and ...));
MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
return [&](auto f) {
constexpr auto size = decltype(x.size()){};
if constexpr((is_vectorizable<typename T::value_type>() or
......
......@@ -28,9 +28,54 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
// When both the global and local sizes are fixed at compile time, derive the
// group count as MIGRAPHX_NGLOBAL / MIGRAPHX_NLOCAL rounded up, so a partial
// trailing group is still counted.
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
#endif
// Global launch size: the compile-time MIGRAPHX_NGLOBAL value when that macro
// is defined, otherwise the product blockDim.x * gridDim.x read at run time.
inline __device__ __attribute__((const)) index_int compute_global_size()
{
#ifndef MIGRAPHX_NGLOBAL
    const index_int total = blockDim.x * gridDim.x; // NOLINT
#else
    const index_int total = MIGRAPHX_NGLOBAL;
#endif
    return total;
}
// Size of the calling group.
//
// Every group other than the last one (blockIdx.x == ngroup - 1) has `nlocal`
// entries. The last group holds the remainder `nglobal % nlocal` when the
// global size is not a multiple of the local size. When it IS a multiple the
// remainder is zero, which means the last group is completely full, so
// `nlocal` is returned instead of 0.
//
// nlocal / ngroup come from MIGRAPHX_NLOCAL / MIGRAPHX_NGROUP when those
// macros are defined and from blockDim.x / gridDim.x otherwise.
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
    const auto nlocal = MIGRAPHX_NLOCAL;
#else
    const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
    const auto ngroup = MIGRAPHX_NGROUP;
#else
    const auto ngroup = gridDim.x; // NOLINT
#endif
    const auto group_id = blockIdx.x; // NOLINT
    const auto nglobal  = compute_global_size();
    // Number of entries left over for a partially filled trailing group.
    const auto leftover = nglobal % nlocal;
    // A zero remainder means the trailing group is full; never report 0.
    if(group_id == ngroup - 1 and leftover != 0)
    {
        return leftover;
    }
    return nlocal; // NOLINT
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local (or there is only a single group) then
// local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
// This is the name tested by the `#ifdef` that selects the constexpr
// index::nlocal() overload; defining any other name leaves it disabled.
#define MIGRAPHX_HAS_CONST_LOCAL 1
// Earlier spelling, kept so any existing check of it keeps working.
#define MIGRAPHX_CONST_LOCAL 1
#endif
#endif
struct index
{
index_int global = 0;
......@@ -42,16 +87,16 @@ struct index
#else
__device__ index_int nglobal() const
{
return blockDim.x * gridDim.x; // NOLINT
return compute_global_size(); // NOLINT
}
#endif
#ifdef MIGRAPHX_NLOCAL
#ifdef MIGRAPHX_HAS_CONST_LOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
#else
__device__ index_int nlocal() const
{
return blockDim.x; // NOLINT
return compute_local_size(); // NOLINT
}
#endif
template <class N, class Stride>
......@@ -63,6 +108,7 @@ struct index
template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1)
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: a "convolution" op with group = 4 applied to
// parameters "x" and "w", whose result is added to parameter "b" after "b" is
// broadcast along axis 1 to the lens {1, 68, 28, 28}.
struct test_conv_group_add : verify_program<test_conv_group_add>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape input_shape{migraphx::shape::float_type, {1, 68, 28, 28}};
        const migraphx::shape weight_shape{migraphx::shape::float_type, {68, 17, 1, 1}};
        const migraphx::shape bias_shape{migraphx::shape::float_type, {68}};

        // Parameters are registered in the same order as before: x, w, b.
        auto input   = main_mod->add_parameter("x", input_shape);
        auto weights = main_mod->add_parameter("w", weight_shape);
        auto bias    = main_mod->add_parameter("b", bias_shape);

        const auto conv_op  = migraphx::make_op("convolution", {{"group", 4}});
        auto conv_out       = main_mod->add_instruction(conv_op, input, weights);

        const auto bcast_op =
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}});
        auto bias_bcast = main_mod->add_instruction(bcast_op, bias);

        main_mod->add_instruction(migraphx::make_op("add"), conv_out, bias_bcast);
        return prog;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment