Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d8a2ed68
"...tutorials/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "fbee0df1205ec0f9557c24398526510b92da306f"
Commit
d8a2ed68
authored
Oct 23, 2022
by
Paul
Browse files
Add missing file
parent
34c4c24a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
0 deletions
+68
-0
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
...ets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
+68
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
0 → 100644
View file @
d8a2ed68
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>
namespace
migraphx
{
template
<
class
Shape
>
constexpr
auto
gemm_get_batches
()
{
constexpr
auto
lens
=
Shape
{}.
lens
;
constexpr
auto
strides
=
Shape
{}.
strides
;
constexpr
auto
new_lens
=
sequence
(
lens
.
size
()
-
_c
<
2
>
,
[](
auto
...
is
)
{
return
make_const_array
(
_c
<
lens
[
is
]
>
...);
});
constexpr
auto
new_strides
=
sequence
(
strides
.
size
()
-
_c
<
2
>
,
[](
auto
...
is
)
{
return
make_const_array
(
_c
<
strides
[
is
]
>
...);
});
return
make_shape
(
new_lens
,
new_strides
);
}
template
<
class
Shape
>
constexpr
auto
gemm_get_matrix
()
{
constexpr
auto
lens
=
Shape
{}.
lens
;
constexpr
auto
strides
=
Shape
{}.
strides
;
constexpr
auto
m
=
lens
.
size
()
-
_c
<
2
>
;
constexpr
auto
n
=
lens
.
size
()
-
_c
<
1
>
;
constexpr
auto
new_lens
=
make_const_array
(
_c
<
lens
[
m
]
>
,
_c
<
lens
[
n
]
>
);
constexpr
auto
new_strides
=
make_const_array
(
_c
<
strides
[
m
]
>
,
_c
<
strides
[
n
]
>
);
return
make_shape
(
new_lens
,
new_strides
);
}
template
<
class
Tensor
,
class
T
>
constexpr
auto
gemm_batch_slice
(
Tensor
t
,
T
i
)
{
constexpr
auto
batch
=
gemm_get_batches
<
Tensor
>
();
constexpr
auto
matrix
=
gemm_get_matrix
<
Tensor
>
();
return
make_tensor_view
(
t
.
data
()
+
matrix
.
index
(
i
),
matrix
);
}
template
<
class
BlocksPerBatch
,
class
T
,
class
...
Ts
>
constexpr
auto
gemm_batch_args
(
index
idx
,
BlocksPerBatch
bpb
,
T
x
,
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
// All tensors should have the same rank
static_assert
((
true
and
...
and
(
get_shape_c
<
T
>
{}.
lens
.
size
()
==
get_shape_c
<
Ts
>
{}.
lens
.
size
())));
if
constexpr
(
get_shape_c
<
T
>
{}.
lens
.
size
()
>
2
)
{
// Get the first batch since all batches should have the same number of elements
constexpr
auto
batch
=
gemm_get_batches
<
T
>
();
static_assert
((
true
and
...
and
(
batch
.
elements
()
==
gemm_get_batches
<
Ts
>
().
elements
())));
idx
.
group_stride
(
bpb
*
batch
.
elements
(),
[
&
](
auto
gidx
)
{
constexpr
auto
batch_idx
=
gidx
/
bpb
;
f
(
gemm_batch_slice
(
x
,
batch_idx
),
gemm_batch_slice
(
xs
,
batch_idx
)...);
});
}
else
{
f
(
x
,
xs
...);
}
};
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment