Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
351fde4d
Commit
351fde4d
authored
Sep 08, 2022
by
Paul
Browse files
Handle non-const local
parent
c78ce73d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
4 deletions
+95
-4
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+49
-3
test/verify/test_conv_group_add.cpp
test/verify/test_conv_group_add.cpp
+45
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
351fde4d
...
...
@@ -91,7 +91,7 @@ __device__ auto& array2vec(T& x)
template
<
class
T
,
class
...
Ts
>
constexpr
auto
array_for_each
(
T
&
x
,
Ts
&
...
xs
)
{
MIGRAPHX_ASSERT
((
x
.
size
()
==
xs
.
size
()
and
...));
MIGRAPHX_ASSERT
((
(
x
.
size
()
==
xs
.
size
()
)
and
...));
return
[
&
](
auto
f
)
{
constexpr
auto
size
=
decltype
(
x
.
size
()){};
if
constexpr
((
is_vectorizable
<
typename
T
::
value_type
>
()
or
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
351fde4d
...
...
@@ -28,9 +28,54 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
// When both the global and local sizes are supplied as macros, derive the
// group count as ceil(MIGRAPHX_NGLOBAL / MIGRAPHX_NLOCAL): adding
// (MIGRAPHX_NLOCAL - 1) before the integer division rounds up, so a partial
// trailing group is counted as well.
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
#endif
/// Global size used by the kernel.
///
/// Returns MIGRAPHX_NGLOBAL when that macro is defined; otherwise falls back
/// to the product blockDim.x * gridDim.x read from the launch configuration.
inline __device__ __attribute__((const)) index_int compute_global_size()
{
#ifndef MIGRAPHX_NGLOBAL
    // No compile-time size supplied: derive it from the launch dimensions.
    const index_int launched = blockDim.x * gridDim.x; // NOLINT
    return launched;
#else
    return MIGRAPHX_NGLOBAL;
#endif
}
/// Local size of the group identified by blockIdx.x.
///
/// The nominal local size is MIGRAPHX_NLOCAL when defined, otherwise
/// blockDim.x; the group count is MIGRAPHX_NGROUP when defined, otherwise
/// gridDim.x. Every group except the last reports the nominal local size. The
/// last group reports whatever remains of the global size, which is a full
/// `nlocal` when the global size is an exact multiple of it.
///
/// @return the number of items handled by this group, in the range
///         [1, nlocal] (assumes the global size is at least 1).
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
    const auto nlocal = MIGRAPHX_NLOCAL;
#else
    const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
    const auto ngroup = MIGRAPHX_NGROUP;
#else
    const auto ngroup = gridDim.x; // NOLINT
#endif
    const auto group_id = blockIdx.x; // NOLINT
    const auto nglobal  = compute_global_size();
    if(group_id == ngroup - 1)
    {
        // A plain `nglobal % nlocal` yields 0 when nglobal is an exact
        // multiple of nlocal, which would give the last group a size of zero.
        // Shifting by one maps the remainder into [1, nlocal] instead:
        // an exact multiple gives nlocal, anything else gives the remainder.
        return 1 + (nglobal - 1) % nlocal;
    }
    else
    {
        return nlocal; // NOLINT
    }
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if (MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_CONST_LOCAL 1
// struct index picks its constexpr nlocal() overload by testing
// MIGRAPHX_HAS_CONST_LOCAL, so that name must be defined here as well;
// otherwise the compile-time path is never selected. MIGRAPHX_CONST_LOCAL is
// kept so any existing check of the old name keeps working.
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
struct
index
{
index_int
global
=
0
;
...
...
@@ -42,16 +87,16 @@ struct index
#else
__device__
index_int
nglobal
()
const
{
return
blockDim
.
x
*
gridDim
.
x
;
// NOLINT
return
compute_global_size
()
;
// NOLINT
}
#endif
#ifdef MIGRAPHX_
N
LOCAL
#ifdef MIGRAPHX_
HAS_CONST_
LOCAL
constexpr
index_constant
<
MIGRAPHX_NLOCAL
>
nlocal
()
const
{
return
{};
}
#else
__device__
index_int
nlocal
()
const
{
return
blockDim
.
x
;
// NOLINT
return
compute_local_size
()
;
// NOLINT
}
#endif
template
<
class
N
,
class
Stride
>
...
...
@@ -63,6 +108,7 @@ struct index
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
MIGRAPHX_ASSERT
(
start
<
stride
);
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{}
and
max_stride_iterations
(
n
,
stride
)
==
1
)
{
...
...
test/verify/test_conv_group_add.cpp
0 → 100644
View file @
351fde4d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: a grouped 1x1 convolution (4 groups) over a
// {1, 68, 28, 28} float input, followed by adding a per-channel bias that is
// broadcast along axis 1 to the convolution's output dimensions.
struct test_conv_group_add : verify_program<test_conv_group_add>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        // Parameter shapes: 68 input channels split over 4 groups gives the
        // 17 per-group channels in the weight tensor.
        const migraphx::shape input_shape{migraphx::shape::float_type, {1, 68, 28, 28}};
        const migraphx::shape weight_shape{migraphx::shape::float_type, {68, 17, 1, 1}};
        const migraphx::shape bias_shape{migraphx::shape::float_type, {68}};

        auto input   = main_mod->add_parameter("x", input_shape);
        auto weights = main_mod->add_parameter("w", weight_shape);
        auto bias    = main_mod->add_parameter("b", bias_shape);

        const auto conv_op = migraphx::make_op("convolution", {{"group", 4}});
        auto conv_out      = main_mod->add_instruction(conv_op, input, weights);

        const auto bcast_op =
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}});
        auto bias_bcast = main_mod->add_instruction(bcast_op, bias);

        main_mod->add_instruction(migraphx::make_op("add"), conv_out, bias_bcast);
        return prog;
    }
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment