Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ed22ba4f
Commit
ed22ba4f
authored
Oct 23, 2022
by
Paul
Browse files
Format
parent
711aaed9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
21 deletions
+21
-21
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
...ets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
+21
-21
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
View file @
ed22ba4f
...
@@ -7,51 +7,51 @@
...
@@ -7,51 +7,51 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
class
Tensor
>
template
<
class
Tensor
>
constexpr
auto
gemm_get_batches
()
constexpr
auto
gemm_get_batches
()
{
{
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
new_lens
=
sequence
(
lens
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
constexpr
auto
new_lens
=
sequence
(
return
make_const_array
(
_c
<
lens
[
is
]
>
...);
lens
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
lens
[
is
]
>
...);
});
});
constexpr
auto
new_strides
=
sequence
(
constexpr
auto
new_strides
=
sequence
(
strides
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
strides
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
strides
[
is
]
>
...);
});
return
make_const_array
(
_c
<
strides
[
is
]
>
...);
});
return
make_shape
(
new_lens
,
new_strides
);
return
make_shape
(
new_lens
,
new_strides
);
}
}
template
<
class
Tensor
>
template
<
class
Tensor
>
constexpr
auto
gemm_get_matrix
()
constexpr
auto
gemm_get_matrix
()
{
{
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
m
=
lens
.
size
()
-
_c
<
2
>
;
constexpr
auto
m
=
lens
.
size
()
-
_c
<
2
>
;
constexpr
auto
n
=
lens
.
size
()
-
_c
<
1
>
;
constexpr
auto
n
=
lens
.
size
()
-
_c
<
1
>
;
constexpr
auto
new_lens
=
make_const_array
(
_c
<
lens
[
m
]
>
,
_c
<
lens
[
n
]
>
);
constexpr
auto
new_lens
=
make_const_array
(
_c
<
lens
[
m
]
>
,
_c
<
lens
[
n
]
>
);
constexpr
auto
new_strides
=
make_const_array
(
_c
<
strides
[
m
]
>
,
_c
<
strides
[
n
]
>
);
constexpr
auto
new_strides
=
make_const_array
(
_c
<
strides
[
m
]
>
,
_c
<
strides
[
n
]
>
);
return
make_shape
(
new_lens
,
new_strides
);
return
make_shape
(
new_lens
,
new_strides
);
}
}
template
<
class
Tensor
,
class
T
>
template
<
class
Tensor
,
class
T
>
constexpr
auto
gemm_batch_slice
(
Tensor
t
,
T
i
)
constexpr
auto
gemm_batch_slice
(
Tensor
t
,
T
i
)
{
{
constexpr
auto
batch
=
gemm_get_batches
<
Tensor
>
();
constexpr
auto
batch
=
gemm_get_batches
<
Tensor
>
();
constexpr
auto
matrix
=
gemm_get_matrix
<
Tensor
>
();
constexpr
auto
matrix
=
gemm_get_matrix
<
Tensor
>
();
return
make_tensor_view
(
t
.
data
()
+
batch
.
index
(
i
),
matrix
);
return
make_tensor_view
(
t
.
data
()
+
batch
.
index
(
i
),
matrix
);
}
}
template
<
class
BlocksPerBatch
,
class
T
,
class
...
Ts
>
template
<
class
BlocksPerBatch
,
class
T
,
class
...
Ts
>
constexpr
auto
gemm_batch_args
(
index
idx
,
BlocksPerBatch
bpb
,
T
x
,
Ts
...
xs
)
constexpr
auto
gemm_batch_args
(
index
idx
,
BlocksPerBatch
bpb
,
T
x
,
Ts
...
xs
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// All tensors should have the same rank
// All tensors should have the same rank
static_assert
((
true
and
...
and
(
get_shape_c
<
T
>
{}.
lens
.
size
()
==
get_shape_c
<
Ts
>
{}.
lens
.
size
())));
static_assert
(
(
true
and
...
and
(
get_shape_c
<
T
>
{}.
lens
.
size
()
==
get_shape_c
<
Ts
>
{}.
lens
.
size
())));
if
constexpr
(
get_shape_c
<
T
>
{}.
lens
.
size
()
>
2
)
if
constexpr
(
get_shape_c
<
T
>
{}.
lens
.
size
()
>
2
)
{
{
// Get the first batch since all batches should have the same number of elements
// Get the first batch since all batches should have the same number of elements
constexpr
auto
batch
=
gemm_get_batches
<
T
>
();
constexpr
auto
batch
=
gemm_get_batches
<
T
>
();
static_assert
((
true
and
...
and
(
batch
.
elements
()
==
gemm_get_batches
<
Ts
>
().
elements
())));
static_assert
(
(
true
and
...
and
(
batch
.
elements
()
==
gemm_get_batches
<
Ts
>
().
elements
())));
idx
.
group_stride
(
bpb
*
batch
.
elements
(),
[
&
](
auto
gidx
)
{
idx
.
group_stride
(
bpb
*
batch
.
elements
(),
[
&
](
auto
gidx
)
{
const
auto
batch_idx
=
gidx
/
bpb
;
const
auto
batch_idx
=
gidx
/
bpb
;
f
(
gemm_batch_slice
(
x
,
batch_idx
),
gemm_batch_slice
(
xs
,
batch_idx
)...);
f
(
gemm_batch_slice
(
x
,
batch_idx
),
gemm_batch_slice
(
xs
,
batch_idx
)...);
...
@@ -59,7 +59,7 @@ constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
...
@@ -59,7 +59,7 @@ constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
}
}
else
else
{
{
f
(
x
,
xs
...);
f
(
x
,
xs
...);
}
}
};
};
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment