Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f2c9c70b
Commit
f2c9c70b
authored
Jul 01, 2022
by
Paul
Browse files
Merge branch 'jit-improve' into jit-layernorm-unroll
parents
6a8c231e
6deee23b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
14 deletions
+38
-14
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+36
-13
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
+2
-1
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
f2c9c70b
...
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
...
...
@@ -53,29 +54,51 @@ struct index
return
blockDim
.
x
;
// NOLINT
}
#endif
// Returns (n - 1) / stride + 1, i.e. an upper bound on how many values
// start, start + stride, start + 2 * stride, ... can be below n.
// For integer operands with n >= 1 this equals ceil(n / stride).
//
// The literal ones are spelled `_c<1>` rather than `1`; presumably this keeps
// the result a compile-time constant type when both n and stride are
// constants -- confirm against integral_constant.hpp.
// NOTE(review): if N is an unsigned type and n == 0, `n - 1` would wrap
// around -- confirm callers never pass zero, or that this is harmless.
template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride)
{
    // Number of whole strides that fit below the last valid position (n - 1).
    const auto whole_strides = (n - _c<1>) / stride;
    // One more to account for the very first position.
    return whole_strides + _c<1>;
}
template
<
class
F
>
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
const
auto
stride
=
nglobal
();
for
(
in
de
x
_i
nt
i
=
global
;
i
<
n
;
i
+=
stride
)
if
const
expr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{}
and
max_stri
de_i
terations
(
n
,
stride
)
==
1
)
{
f
(
i
);
if
constexpr
(
stride
>
n
)
{
if
(
start
<
n
)
f
(
start
);
}
else
{
f
(
start
);
}
}
else
{
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
}
}
template
<
class
F
>
__device__
void
lo
c
al_stride
(
index_int
n
,
F
f
)
const
template
<
class
F
,
class
N
>
__device__
void
g
lo
b
al_stride
(
N
n
,
F
f
)
const
{
const
auto
stride
=
nlocal
();
for
(
index_int
i
=
local
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
for_stride
(
global
,
n
,
nglobal
(),
f
);
}
// Invokes f(i) for i = local, local + nlocal(), local + 2 * nlocal(), ...
// while i < n, by delegating the iteration to for_stride().
//
// n may be any type accepted by for_stride (N is deduced), so both runtime
// sizes and constant-typed sizes are forwarded unchanged.
template <class F, class N>
__device__ void local_stride(N n, F f) const
{
    // Step between consecutive indices handled by this thread.
    const auto step = nlocal();
    for_stride(local, n, step, f);
}
};
inline
__device__
index
make_index
()
// Builds the `index` for the calling device thread from the x components of
// the launch coordinates. The three initializers are, in order:
//   1. blockIdx.x * blockDim.x + threadIdx.x
//   2. threadIdx.x
//   3. blockIdx.x
// Presumably these populate index::global, index::local and the block/group
// id respectively -- the member order of `index` is not visible here; confirm.
// Only the x dimension is read, so this assumes a 1-D launch -- TODO confirm.
//
// __attribute__((const)) declares to the compiler that the function has no
// side effects and that repeated calls may be merged.
// NOTE(review): the result depends on per-thread built-in coordinates rather
// than on arguments; confirm that is compatible with the attribute's contract.
inline
__device__
__attribute__
((
const
))
index
make_index
()
{
return
index
{
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
,
threadIdx
.
x
,
blockIdx
.
x
};
// NOLINT
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
View file @
f2c9c70b
...
...
@@ -186,7 +186,8 @@ __device__ auto auto_preload(index idx)
{
return
make_transform
([
=
](
auto
f
,
auto
...
xs
)
{
auto
invoke
=
[
=
](
auto
...
ys
)
{
__syncthreads
();
if
constexpr
((
Bs
or
...))
__syncthreads
();
f
(
ys
...);
};
join
(
invoke
,
preload_copy
<
Bs
>
(
idx
,
xs
)...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment