Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
127393f4
Commit
127393f4
authored
Sep 09, 2022
by
turneram
Browse files
Call gemm from kernel
parent
9d12476e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
329 additions
and
134 deletions
+329
-134
src/include/migraphx/op/ck_elementwise.hpp
src/include/migraphx/op/ck_elementwise.hpp
+77
-0
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+12
-2
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+238
-130
test/verify/0ck_gemm_test.cpp
test/verify/0ck_gemm_test.cpp
+2
-2
test/verify/1ck_elementwise_test.cpp
test/verify/1ck_elementwise_test.cpp
+0
-0
No files found.
src/include/migraphx/op/ck_elementwise.hpp
0 → 100644
View file @
127393f4
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_ELEMENTWISE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/// Operator that adds its two inputs element by element.
///
/// compute_shape() requires exactly two inputs with the same element type and
/// the same dimensions; compute() then writes input0[i] + input1[i] for every
/// element index of the output shape.
struct ck_elementwise
{
    /// Name under which the operator is identified.
    std::string name() const { return "ck_elementwise"; }

    /// Validates the inputs and selects the output shape.
    ///
    /// Selection order:
    ///   1. both shapes are equal and packed      -> that shape;
    ///   2. exactly one of the shapes is packed   -> the packed one;
    ///   3. exactly one of the shapes is broadcast -> the other one, passed
    ///      through with_lens() using the first input's dimensions;
    ///   4. otherwise                             -> a shape built from the
    ///      first input's type and dimensions only.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_dims();

        auto lhs = inputs.at(0);
        auto rhs = inputs.at(1);

        if(lhs == rhs and lhs.packed())
        {
            return lhs;
        }

        if(lhs.packed() != rhs.packed())
        {
            return lhs.packed() ? lhs : rhs;
        }

        if(lhs.broadcasted() != rhs.broadcasted())
        {
            if(lhs.broadcasted())
            {
                return rhs.with_lens(lhs.lens());
            }
            return lhs.with_lens(lhs.lens());
        }

        return {lhs.type(), lhs.lens()};
    }

    /// Allocates the result for output_shape and fills it with the
    /// element-wise sum of args[0] and args[1].
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto out, auto lhs, auto rhs) {
            // One iteration per output element, indexed linearly.
            const auto count = output_shape.elements();
            par_for(count, [&](const auto i) { out[i] = lhs[i] + rhs[i]; });
        });
        return result;
    }
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/jit/ck_gemm.cpp
View file @
127393f4
...
...
@@ -47,14 +47,24 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
extern "C" {
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
ck_gemm(xs...);
hipDeviceProp_t hdp{};
printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
// make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
// ck_gemm(xs...);
// });
make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
__shared__ void* p_shared_block[(a_t.get_shape().elements()/* + b_t.get_shape().elements() */) * 2];
make_tensors()(p_shared_block)([&](auto p_t) {
ck_gemm(a_t, b_t, c_t, p_t);
});
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
127393f4
...
...
@@ -26,6 +26,8 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
...
@@ -39,143 +41,249 @@
namespace
migraphx
{
// static constexpr auto I0 = Number<0>{};
// static constexpr auto I1 = Number<1>{};
// static constexpr auto I2 = Number<2>{};
// static constexpr auto I3 = Number<3>{};
// static constexpr auto I4 = Number<4>{};
// static constexpr auto I5 = Number<5>{};
// static constexpr auto K1Number = Number<1>{};
// static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto a_grid_desc_m_k = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(M, PadM)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
// {
// assert(K % K1 == 0);
// const index_t K0 = K / K1;
// const auto b_grid_desc_k_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// b_grid_desc_k_n,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// }
// }
// static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
// {
// const auto c_grid_desc_m_n = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
// }
// }();
// if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// {
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else
// {
// return transform_tensor_descriptor(
// c_grid_desc_m_n,
// make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// }
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_gemm
(
const
T
&
/* a_t */
,
const
U
&
/* b_t */
,
const
V
&
/* c_t */
)
// Compile-time index constants.
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};

// K is split as K = K0 * K1 by the grid descriptors below.
static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number  = ck::Number<K1>{};

// Matrix layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ALayout = Col;
using BLayout = Row;
using CLayout = Row;

// Element types of the operands and of the accumulator.
using ADataType   = float;
using BDataType   = float;
using CDataType   = float;
using AccDataType = float;

// Default specialization: the descriptor builders take their non-padding branch.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

// Shorthand for a compile-time integer sequence.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Values hard-coded by CK
static constexpr ck::index_t MPerBlock   = 128;
static constexpr ck::index_t NPerBlock   = 128;
static constexpr ck::index_t BlockSize   = 256;
static constexpr ck::index_t K0PerBlock  = 16;
static constexpr ck::index_t M1PerThread = 4;
static constexpr ck::index_t N1PerThread = 4;
static constexpr ck::index_t KPerThread  = 1;

using M1N1ThreadClusterM1Xs = S<8, 2>;
using M1N1ThreadClusterN1Xs = S<8, 2>;

// Block-transfer parameters for A.
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1     = S<2, 1, 4, 1>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1   = S<8, 1, 32, 1>;
using ABlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using ABlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;

// Block-transfer parameters for B (same values as for A).
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1     = S<2, 1, 4, 1>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1   = S<8, 1, 32, 1>;
using BBlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using BBlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;

// Thread-transfer parameters for C.
using CThreadTransferSrcDstAccessOrder                         = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim    = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;
/// Builds the grid descriptor for matrix A.
///
/// @param M        number of rows of A
/// @param K        number of columns of A; must be divisible by K1
/// @param StrideA  stride of the non-unit dimension: used as the M stride when
///                 ALayout is RowMajor and as the K stride when it is ColumnMajor
/// @return the (M, K) descriptor transformed so that K is unmerged into
///         (K0, K1Number); M is right-padded up to a multiple of MPerBlock when
///         GemmSpec is MNPadding and passed through otherwise.
static constexpr auto
MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    // Naive (M, K) descriptor; only the stride order depends on the layout.
    const auto desc_m_k = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K),
                                                ck::make_tuple(StrideA, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K),
                                                ck::make_tuple(I1, StrideA));
        }
    }();

    // Dimension ids shared by both branches: the first transform reads source
    // dimension 1 (K) and the second reads source dimension 0 (M); their
    // results are placed at dimensions {0, 2} and {1} respectively.
    const auto src_dim_ids = ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{});
    const auto dst_dim_ids = ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{});

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        // Amount needed to round M up to the next multiple of MPerBlock.
        const auto pad_m = (MPerBlock - M % MPerBlock) % MPerBlock;

        return transform_tensor_descriptor(
            desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(M, pad_m)),
            src_dim_ids,
            dst_dim_ids);
    }
    else
    {
        return transform_tensor_descriptor(
            desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(M)),
            src_dim_ids,
            dst_dim_ids);
    }
}
/// Builds the grid descriptor for matrix B.
///
/// @param K        number of rows of B; must be divisible by K1
/// @param N        number of columns of B
/// @param StrideB  stride of the non-unit dimension: used as the K stride when
///                 BLayout is RowMajor and as the N stride when it is ColumnMajor
/// @return the (K, N) descriptor transformed so that K is unmerged into
///         (K0, K1Number); N is right-padded up to a multiple of NPerBlock when
///         GemmSpec is MNPadding and passed through otherwise.
static constexpr auto
MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ck::index_t StrideB)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    // Naive (K, N) descriptor; only the stride order depends on the layout.
    const auto desc_k_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N),
                                                ck::make_tuple(StrideB, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N),
                                                ck::make_tuple(I1, StrideB));
        }
    }();

    // Dimension ids shared by both branches: the first transform reads source
    // dimension 0 (K) and the second reads source dimension 1 (N); their
    // results are placed at dimensions {0, 2} and {1} respectively.
    const auto src_dim_ids = ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{});
    const auto dst_dim_ids = ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{});

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        // Amount needed to round N up to the next multiple of NPerBlock.
        const auto pad_n = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(N, pad_n)),
            src_dim_ids,
            dst_dim_ids);
    }
    else
    {
        return transform_tensor_descriptor(
            desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(N)),
            src_dim_ids,
            dst_dim_ids);
    }
}
/// Builds the grid descriptor for the output matrix C.
///
/// @param M        number of rows of C
/// @param N        number of columns of C
/// @param StrideC  stride of the non-unit dimension: used as the M stride when
///                 CLayout is RowMajor and as the N stride when it is ColumnMajor
/// @return the (M, N) descriptor; both dimensions are right-padded up to a
///         multiple of MPerBlock / NPerBlock when GemmSpec is MNPadding and
///         passed through otherwise.
static constexpr auto
MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::index_t StrideC)
{
    // Naive (M, N) descriptor; only the stride order depends on the layout.
    const auto desc_m_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N),
                                                ck::make_tuple(StrideC, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N),
                                                ck::make_tuple(I1, StrideC));
        }
    }();

    // Each transform maps one dimension onto the same position, so the same
    // id tuple serves as both the source and the destination mapping.
    const auto dim_ids = ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{});

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        // Amounts needed to round M and N up to the next block multiple.
        const auto pad_m = (MPerBlock - M % MPerBlock) % MPerBlock;
        const auto pad_n = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            desc_m_n,
            ck::make_tuple(ck::make_right_pad_transform(M, pad_m),
                           ck::make_right_pad_transform(N, pad_n)),
            dim_ids,
            dim_ids);
    }
    else
    {
        return transform_tensor_descriptor(
            desc_m_n,
            ck::make_tuple(ck::make_pass_through_transform(M),
                           ck::make_pass_through_transform(N)),
            dim_ids,
            dim_ids);
    }
}
// Descriptor types produced by the builders above. The argument values are
// placeholders: the calls appear only inside decltype and are never evaluated.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
/// Runs the CK GridwiseGemmDl_km_kn_mn_v1r3 kernel body on the given tensors.
///
/// The problem sizes and strides are read at compile time from the tensor
/// types: m and k from the dimensions of T, n from the second dimension of U,
/// and one stride per operand from index 0 of each strides array.
///
/// @param a_t  tensor whose data() is passed to Run as the A operand
/// @param b_t  tensor whose data() is passed to Run as the B operand
/// @param c_t  tensor whose data() is passed to Run as the C operand
/// @param p_t  tensor whose data() is passed to Run as its fourth argument
///             (presumably the shared-memory workspace -- confirm against CK)
template <class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
    constexpr auto alens = get_shape_c<T>{}.lens;
    constexpr auto m     = alens[0];
    constexpr auto k     = alens[1];

    constexpr auto blens = get_shape_c<U>{}.lens;
    constexpr auto n     = blens[1];

    constexpr auto astrides = get_shape_c<T>{}.strides;
    constexpr auto as       = astrides[0];

    constexpr auto bstrides = get_shape_c<U>{}.strides;
    constexpr auto bs       = bstrides[0];

    constexpr auto cstrides = get_shape_c<V>{}.strides;
    constexpr auto cs       = cstrides[0];

    // Debug trace of the problem configuration, printed by a single thread
    // only so the output is not repeated for every thread of the launch.
    auto idx = make_index();
    if(idx.global == 0)
        printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));

    auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
        static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
    auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
        static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
    auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
        static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));

    using GridwiseGemm =
        ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                         ADataType,
                                         AccDataType,
                                         CDataType,
                                         ck::InMemoryDataOperationEnum::Set,
                                         AGridDesc_K0_M_K1,
                                         BGridDesc_K0_N_K1,
                                         CGridDesc_M_N,
                                         MPerBlock,
                                         NPerBlock,
                                         K0PerBlock,
                                         M1PerThread,
                                         N1PerThread,
                                         KPerThread,
                                         M1N1ThreadClusterM1Xs,
                                         M1N1ThreadClusterN1Xs,
                                         ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                         ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                         ABlockTransferThreadClusterArrangeOrder,
                                         ABlockTransferSrcAccessOrder,
                                         ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                         ABlockTransferSrcVectorTensorContiguousDimOrder,
                                         ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                         BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                         BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                         BBlockTransferThreadClusterArrangeOrder,
                                         BBlockTransferSrcAccessOrder,
                                         BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                         BBlockTransferSrcVectorTensorContiguousDimOrder,
                                         BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                         CThreadTransferSrcDstAccessOrder,
                                         CThreadTransferSrcDstVectorDim,
                                         CThreadTransferDstScalarPerVector>;

    // Re-express the descriptors in the blocked forms that Run expects.
    auto a_grid_desc_k0_m0_m1_k1 =
        GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
    auto b_grid_desc_k0_n0_n1_k1 =
        GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
    auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
        GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
    auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);

    // NOTE(review): both loop flags are fixed to true regardless of k.
    // Presumably they should be derived from K0 for general problem sizes --
    // confirm against CK's device-level DL GEMM before relying on other shapes.
    constexpr bool HasMainKBlockLoop       = true;
    constexpr bool HasDoubleTailKBlockLoop = true;

    GridwiseGemm::Run(a_t.data(),
                      b_t.data(),
                      c_t.data(),
                      p_t.data(),
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      c_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      ck::integral_constant<bool, HasMainKBlockLoop>{},
                      ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
}
// namespace migraphx
...
...
test/verify/0ck_gemm_test.cpp
View file @
127393f4
...
...
@@ -33,8 +33,8 @@ struct ck_gemm : verify_program<ck_gemm>
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
0
,
2
0
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
0
,
2
0
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
1
28
,
2
56
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
56
,
2
56
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
...
...
test/verify/
0
ck_elementwise_test.cpp
→
test/verify/
1
ck_elementwise_test.cpp
View file @
127393f4
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment