Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
711aaed9
Commit
711aaed9
authored
Oct 23, 2022
by
Paul
Browse files
Compile fixes
parent
d8a2ed68
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
10 deletions
+10
-10
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
...ets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
+10
-10
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
View file @
711aaed9
...
...
@@ -7,25 +7,25 @@
namespace
migraphx
{
template
<
class
Shape
>
template
<
class
Tensor
>
constexpr
auto
gemm_get_batches
()
{
constexpr
auto
lens
=
Shape
{}.
lens
;
constexpr
auto
strides
=
Shape
{}.
strides
;
constexpr
auto
new_lens
=
sequence
(
lens
.
size
()
-
_c
<
2
>
,
[](
auto
...
is
)
{
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
new_lens
=
sequence
(
lens
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
lens
[
is
]
>
...);
});
constexpr
auto
new_strides
=
sequence
(
strides
.
size
()
-
_c
<
2
>
,
[](
auto
...
is
)
{
constexpr
auto
new_strides
=
sequence
(
strides
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
strides
[
is
]
>
...);
});
return
make_shape
(
new_lens
,
new_strides
);
}
template
<
class
Shape
>
template
<
class
Tensor
>
constexpr
auto
gemm_get_matrix
()
{
constexpr
auto
lens
=
Shape
{}.
lens
;
constexpr
auto
strides
=
Shape
{}.
strides
;
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
m
=
lens
.
size
()
-
_c
<
2
>
;
constexpr
auto
n
=
lens
.
size
()
-
_c
<
1
>
;
constexpr
auto
new_lens
=
make_const_array
(
_c
<
lens
[
m
]
>
,
_c
<
lens
[
n
]
>
);
...
...
@@ -38,7 +38,7 @@ constexpr auto gemm_batch_slice(Tensor t, T i)
{
constexpr
auto
batch
=
gemm_get_batches
<
Tensor
>
();
constexpr
auto
matrix
=
gemm_get_matrix
<
Tensor
>
();
return
make_tensor_view
(
t
.
data
()
+
m
at
rix
.
index
(
i
),
matrix
);
return
make_tensor_view
(
t
.
data
()
+
b
at
ch
.
index
(
i
),
matrix
);
}
template
<
class
BlocksPerBatch
,
class
T
,
class
...
Ts
>
...
...
@@ -53,7 +53,7 @@ constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
constexpr
auto
batch
=
gemm_get_batches
<
T
>
();
static_assert
((
true
and
...
and
(
batch
.
elements
()
==
gemm_get_batches
<
Ts
>
().
elements
())));
idx
.
group_stride
(
bpb
*
batch
.
elements
(),
[
&
](
auto
gidx
)
{
const
expr
auto
batch_idx
=
gidx
/
bpb
;
const
auto
batch_idx
=
gidx
/
bpb
;
f
(
gemm_batch_slice
(
x
,
batch_idx
),
gemm_batch_slice
(
xs
,
batch_idx
)...);
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment