gaoqiong / MIGraphX · Commits · 127393f4

Commit 127393f4, authored Sep 09, 2022 by turneram
Parent: 9d12476e

Call gemm from kernel
Showing 5 changed files, with 329 additions and 134 deletions:

src/include/migraphx/op/ck_elementwise.hpp  (+77, -0)
src/targets/gpu/jit/ck_gemm.cpp  (+12, -2)
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp  (+238, -130)
test/verify/0ck_gemm_test.cpp  (+2, -2)
test/verify/1ck_elementwise_test.cpp  (+0, -0)
src/include/migraphx/op/ck_elementwise.hpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct ck_elementwise
{
    std::string name() const { return "ck_elementwise"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_dims();
        auto s0 = inputs.at(0);
        auto s1 = inputs.at(1);
        if(s0 == s1 and s0.packed())
        {
            return s0;
        }
        else if(s0.packed() != s1.packed())
        {
            return s0.packed() ? s0 : s1;
        }
        else if(s0.broadcasted() != s1.broadcasted())
        {
            return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
        }
        else
        {
            return {s0.type(), s0.lens()};
        }
    }

    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
            par_for(output_shape.elements(), [&](const auto i) { output[i] = input1[i] + input2[i]; });
        });
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
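As a quick orientation aid (not part of the commit), the new op can be exercised from a small program in the same way the verify tests in this repository build theirs. The helper function name and the parameter names below are illustrative, not taken from the source.

#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/op/ck_elementwise.hpp>

// Illustrative sketch: a two-input program whose only instruction is the new
// ck_elementwise op (which, as written above, performs an element-wise add).
migraphx::program make_ck_elementwise_program()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {128, 256}};
    auto x = mm->add_parameter("x", s);
    auto y = mm->add_parameter("y", s);
    mm->add_instruction(migraphx::op::ck_elementwise{}, x, y);
    return p;
}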
src/targets/gpu/jit/ck_gemm.cpp
...
...
@@ -47,14 +47,24 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
extern "C" {
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
-    make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
-        ck_gemm(xs...);
+    hipDeviceProp_t hdp{};
+    printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
+    // make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
+    //     ck_gemm(xs...);
+    // });
+    make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
+        __shared__ void* p_shared_block[(a_t.get_shape().elements()/* + b_t.get_shape().elements() */) * 2];
+        make_tensors()(p_shared_block)([&](auto p_t) {
+            ck_gemm(a_t, b_t, c_t, p_t);
+        });
    });
}
...
...
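One note on the kernel above: hdp is value-initialized inside device code and hipGetDeviceProperties is never called (it is a host-only API), so the printed sharedMemPerBlock is 0. The real per-block LDS limit has to be queried on the host; a minimal sketch, separate from this commit, assuming device 0:

#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    hipDeviceProp_t prop{};
    // Host-side query of the device properties for device 0.
    if(hipGetDeviceProperties(&prop, 0) == hipSuccess)
        std::printf("Shared mem per block: %zu bytes\n", prop.sharedMemPerBlock);
    return 0;
}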
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
...
...
@@ -26,6 +26,8 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
...
@@ -39,143 +41,249 @@
namespace migraphx {
// static constexpr auto I0 = Number<0>{};
// static constexpr auto I1 = Number<1>{};
// static constexpr auto I2 = Number<2>{};
// static constexpr auto I3 = Number<3>{};
// static constexpr auto I4 = Number<4>{};
// static constexpr auto I5 = Number<5>{};
// static constexpr auto K1Number = Number<1>{};
// static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto a_grid_desc_m_k = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(M, PadM)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto b_grid_desc_k_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
// {
// const auto c_grid_desc_m_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// }
-template <class T, class U, class V>
-__device__ void ck_gemm(const T& /* a_t */, const U& /* b_t */, const V& /* c_t */)
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};

static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number  = ck::Number<K1>{};

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ALayout = Col;
using BLayout = Row;
using CLayout = Row;

using ADataType   = float;
using BDataType   = float;
using CDataType   = float;
using AccDataType = float;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Values hard-coded by CK
static constexpr ck::index_t MPerBlock   = 128;
static constexpr ck::index_t NPerBlock   = 128;
static constexpr ck::index_t BlockSize   = 256;
static constexpr ck::index_t K0PerBlock  = 16;
static constexpr ck::index_t M1PerThread = 4;
static constexpr ck::index_t N1PerThread = 4;
static constexpr ck::index_t KPerThread  = 1;

using M1N1ThreadClusterM1Xs = S<8, 2>;
using M1N1ThreadClusterN1Xs = S<8, 2>;

using ABlockTransferThreadSliceLengths_K0_M0_M1_K1     = S<2, 1, 4, 1>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1   = S<8, 1, 32, 1>;
using ABlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using ABlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;

using BBlockTransferThreadSliceLengths_K0_N0_N1_K1     = S<2, 1, 4, 1>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1   = S<8, 1, 32, 1>;
using BBlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using BBlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;

using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;

static constexpr ck::index_t CThreadTransferSrcDstVectorDim    = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
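// Illustrative arithmetic for the K1/K0 split used by the descriptors below (not part
// of the commit): with K1 = 1, a K of 256 becomes K0 = 256 slices of width K1, which
// the gridwise gemm walks through roughly K0PerBlock = 16 at a time.
static_assert(256 / 1 == 256 && 256 % 16 == 0, "example K decomposition");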
static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    const auto a_grid_desc_m_k = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(StrideA, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(I1, StrideA));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(M, PadM)),
            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(M)),
            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
}

static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ck::index_t StrideB)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    const auto b_grid_desc_k_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(StrideB, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(I1, StrideB));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            b_grid_desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(N, PadN)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            b_grid_desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(N)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
}

static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::index_t StrideC)
{
    const auto c_grid_desc_m_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(StrideC, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(I1, StrideC));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(ck::make_right_pad_transform(M, PadM), ck::make_right_pad_transform(N, PadN)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(ck::make_pass_through_transform(M), ck::make_pass_through_transform(N)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    }
}
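// Worked example of the right-pad computation above (illustrative only, not part of
// the commit): the pad rounds an extent up to the next multiple of the block tile and
// is zero when the extent is already aligned. The helper name is hypothetical.
constexpr ck::index_t example_pad(ck::index_t m, ck::index_t m_per_block)
{
    return (m_per_block - m % m_per_block) % m_per_block;
}
static_assert(example_pad(100, 128) == 28, "100 is padded up to 128");
static_assert(example_pad(256, 128) == 0, "tile multiples need no padding");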
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
template <class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
    constexpr auto alens = get_shape_c<T>{}.lens;
    constexpr auto m     = alens[0];
    constexpr auto k     = alens[1];
-   constexpr auto alens = get_shape_c<U>{}.lens;
-   constexpr auto n     = alens[1];
+   constexpr auto blens = get_shape_c<U>{}.lens;
+   constexpr auto n     = blens[1];
    constexpr auto astrides = get_shape_c<T>{}.strides;
-   constexpr auto as       = astrides[1];
+   constexpr auto as       = astrides[0];
    constexpr auto bstrides = get_shape_c<U>{}.strides;
-   constexpr auto bs       = bstrides[1];
+   constexpr auto bs       = bstrides[0];
    constexpr auto cstrides = get_shape_c<V>{}.strides;
-   constexpr auto cs       = cstrides[1];
-   printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
+   constexpr auto cs       = cstrides[0];

    auto idx = make_index();
    if(idx.global == 0)
        printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));

    auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
        static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
    auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
        static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
    auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
        static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));

    using GridwiseGemm =
        ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                         ADataType,
                                         AccDataType,
                                         CDataType,
                                         ck::InMemoryDataOperationEnum::Set,
                                         AGridDesc_K0_M_K1,
                                         BGridDesc_K0_N_K1,
                                         CGridDesc_M_N,
                                         MPerBlock,
                                         NPerBlock,
                                         K0PerBlock,
                                         M1PerThread,
                                         N1PerThread,
                                         KPerThread,
                                         M1N1ThreadClusterM1Xs,
                                         M1N1ThreadClusterN1Xs,
                                         ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                         ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                         ABlockTransferThreadClusterArrangeOrder,
                                         ABlockTransferSrcAccessOrder,
                                         ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                         ABlockTransferSrcVectorTensorContiguousDimOrder,
                                         ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                         BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                         BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                         BBlockTransferThreadClusterArrangeOrder,
                                         BBlockTransferSrcAccessOrder,
                                         BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                         BBlockTransferSrcVectorTensorContiguousDimOrder,
                                         BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                         CThreadTransferSrcDstAccessOrder,
                                         CThreadTransferSrcDstVectorDim,
                                         CThreadTransferDstScalarPerVector>;

    auto a_grid_desc_k0_m0_m1_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
    auto b_grid_desc_k0_n0_n1_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
    auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
        GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
    auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);

    constexpr bool HasMainKBlockLoop       = true;
    constexpr bool HasDoubleTailKBlockLoop = true;

    GridwiseGemm::Run(a_t.data(),
                      b_t.data(),
                      c_t.data(),
                      p_t.data(),
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      c_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      ck::integral_constant<bool, HasMainKBlockLoop>{},
                      ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

} // namespace migraphx
...
...
test/verify/0ck_gemm_test.cpp
...
...
@@ -33,8 +33,8 @@ struct ck_gemm : verify_program<ck_gemm>
{
    migraphx::program p;
    auto* mm = p.get_main_module();
-   migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
-   migraphx::shape m2_shape{migraphx::shape::float_type, {20, 20}};
+   migraphx::shape m1_shape{migraphx::shape::float_type, {128, 256}};
+   migraphx::shape m2_shape{migraphx::shape::float_type, {256, 256}};
    auto l1 = mm->add_parameter("1", m1_shape);
    auto l2 = mm->add_parameter("2", m2_shape);
...
...
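The test shapes move from 10x20 / 20x20 up to 128x256 / 256x256, presumably so that M, N, and K line up with the hard-coded tile sizes in ck_gemm.hpp (MPerBlock = NPerBlock = 128, K0PerBlock = 16 with K1 = 1) and the Default, non-padding GemmSpecialization stays valid for them. A quick illustrative check, not part of the commit:

static_assert(128 % 128 == 0, "M is a whole number of M tiles");
static_assert(256 % 128 == 0, "N is a whole number of N tiles");
static_assert((256 / 1) % 16 == 0, "K0 is a whole number of K0PerBlock steps");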
test/verify/0ck_elementwise_test.cpp → test/verify/1ck_elementwise_test.cpp
File moved (no content changes).