Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ddb0c230
Commit
ddb0c230
authored
Sep 09, 2022
by
turneram
Browse files
Formatting
parent
127393f4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
83 deletions
+90
-83
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+90
-83
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
ddb0c230
...
...
@@ -48,12 +48,11 @@ static constexpr auto I3 = ck::Number<3>{};
static
constexpr
auto
I4
=
ck
::
Number
<
4
>
{};
static
constexpr
auto
I5
=
ck
::
Number
<
5
>
{};
static
constexpr
ck
::
index_t
K1
=
1
;
static
constexpr
auto
K1Number
=
ck
::
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
ck
::
Number
<
K1
>
{};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
...
...
@@ -69,34 +68,32 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// Values hard-coded by CK
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
BlockSize
=
256
;
static
constexpr
ck
::
index_t
K0PerBlock
=
16
;
static
constexpr
ck
::
index_t
M1PerThread
=
4
;
static
constexpr
ck
::
index_t
N1PerThread
=
4
;
static
constexpr
ck
::
index_t
KPerThread
=
1
;
using
M1N1ThreadClusterM1Xs
=
S
<
8
,
2
>
;
using
M1N1ThreadClusterN1Xs
=
S
<
8
,
2
>
;
using
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
=
S
<
2
,
1
,
4
,
1
>
;
using
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
=
S
<
8
,
1
,
32
,
1
>
;
using
ABlockTransferThreadClusterArrangeOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferSrcAccessOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
ABlockTransferSrcVectorTensorContiguousDimOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
=
S
<
2
,
1
,
4
,
1
>
;
using
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
=
S
<
8
,
1
,
32
,
1
>
;
using
BBlockTransferThreadClusterArrangeOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferSrcAccessOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
BBlockTransferSrcVectorTensorContiguousDimOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
CThreadTransferSrcDstAccessOrder
=
S
<
0
,
1
,
2
,
3
,
4
,
5
>
;
static
constexpr
ck
::
index_t
CThreadTransferSrcDstVectorDim
=
5
;
static
constexpr
ck
::
index_t
CThreadTransferDstScalarPerVector
=
4
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
BlockSize
=
256
;
static
constexpr
ck
::
index_t
K0PerBlock
=
16
;
static
constexpr
ck
::
index_t
M1PerThread
=
4
;
static
constexpr
ck
::
index_t
N1PerThread
=
4
;
static
constexpr
ck
::
index_t
KPerThread
=
1
;
using
M1N1ThreadClusterM1Xs
=
S
<
8
,
2
>
;
using
M1N1ThreadClusterN1Xs
=
S
<
8
,
2
>
;
using
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
=
S
<
2
,
1
,
4
,
1
>
;
using
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
=
S
<
8
,
1
,
32
,
1
>
;
using
ABlockTransferThreadClusterArrangeOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferSrcAccessOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
ABlockTransferSrcVectorTensorContiguousDimOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
=
S
<
2
,
1
,
4
,
1
>
;
using
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
=
S
<
8
,
1
,
32
,
1
>
;
using
BBlockTransferThreadClusterArrangeOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferSrcAccessOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
BBlockTransferSrcVectorTensorContiguousDimOrder
=
S
<
0
,
3
,
1
,
2
>
;
using
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
=
S
<
1
,
1
,
4
,
1
>
;
using
CThreadTransferSrcDstAccessOrder
=
S
<
0
,
1
,
2
,
3
,
4
,
5
>
;
static
constexpr
ck
::
index_t
CThreadTransferSrcDstVectorDim
=
5
;
static
constexpr
ck
::
index_t
CThreadTransferDstScalarPerVector
=
4
;
static
constexpr
auto
MakeAGridDescriptor_K0_M_K1
(
ck
::
index_t
M
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
)
{
...
...
@@ -122,7 +119,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K,
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
K0
,
K1Number
)),
ck
::
make_right_pad_transform
(
M
,
PadM
)),
ck
::
make_right_pad_transform
(
M
,
PadM
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
}
...
...
@@ -131,7 +128,7 @@ static constexpr auto MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K,
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
K0
,
K1Number
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
}
...
...
@@ -161,7 +158,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N,
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
K0
,
K1Number
)),
ck
::
make_right_pad_transform
(
N
,
PadN
)),
ck
::
make_right_pad_transform
(
N
,
PadN
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
}
...
...
@@ -170,7 +167,7 @@ static constexpr auto MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N,
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
K0
,
K1Number
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
}
...
...
@@ -194,11 +191,11 @@ static constexpr auto MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
M
,
PadM
),
ck
::
make_right_pad_transform
(
N
,
PadN
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
M
,
PadM
)
,
ck
::
make_right_pad_transform
(
N
,
PadN
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
else
{
...
...
@@ -229,48 +226,50 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
constexpr
auto
bs
=
bstrides
[
0
];
constexpr
auto
cstrides
=
get_shape_c
<
V
>
{}.
strides
;
constexpr
auto
cs
=
cstrides
[
0
];
auto
idx
=
make_index
();
if
(
idx
.
global
==
0
)
auto
idx
=
make_index
();
if
(
idx
.
global
==
0
)
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
using
GridwiseGemm
=
ck
::
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
AccDataType
,
CDataType
,
ck
::
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
M1PerThread
,
N1PerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
ADataType
,
AccDataType
,
CDataType
,
ck
::
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
M1PerThread
,
N1PerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
auto
a_grid_desc_k0_m0_m1_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
...
...
@@ -280,10 +279,18 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n
);
auto
block_2_ctile_map
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment