Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
de9f5bed
Commit
de9f5bed
authored
May 25, 2021
by
Jing Zhang
Browse files
add kpack into xldops_gemm and blockwise_gemm
parent
776721ab
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
107 additions
and
81 deletions
+107
-81
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+81
-59
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+19
-15
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+6
-6
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
de9f5bed
...
...
@@ -15,7 +15,7 @@ template <index_t BlockSize,
class
BBlockDesc
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KP
erWave
>
index_t
KP
ack
>
struct
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
...
...
@@ -26,8 +26,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
MPerWave
,
NPerWave
,
KPerWave
>
{};
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
M0
=
ABlockDesc
{}.
GetLength
(
I1
);
...
...
@@ -36,6 +34,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
MPerWave
,
NPerWave
,
KPack
>
{};
static
constexpr
index_t
MWaves
=
M1
/
MPerWave
;
static
constexpr
index_t
NWaves
=
N1
/
NPerWave
;
...
...
@@ -59,14 +59,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
0
,
m_offset
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
);
return
make_tuple
(
k_offset
,
0
,
m_offset
,
0
);
}
else
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
0
,
m_offset
);
return
make_tuple
(
k_offset
,
0
,
m_offset
,
0
);
}
}
...
...
@@ -81,14 +81,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
0
,
n_offset
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
);
return
make_tuple
(
k_offset
,
0
,
n_offset
,
0
);
}
else
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
0
,
n_offset
);
return
make_tuple
(
k_offset
,
0
,
n_offset
,
0
);
}
}
...
...
@@ -120,8 +120,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
static_assert
(
ABlockDesc
{}.
GetLength
(
I3
)
==
BBlockDesc
{}.
GetLength
(
I3
),
"wrong! KPack dimension not consistent"
);
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
KPack
==
BBlockDesc
{}.
GetLength
(
I3
),
"KPack is wrong!"
);
constexpr
index_t
KPerBlock
=
ABlockDesc
{}.
GetLength
(
I0
);
static_assert
(
KPerBlock
%
xdlops_gemm
.
KPerXdlops
==
0
,
"KPerBlock is wrong!"
);
static_assert
(
KPack
%
xdlops_gemm
.
mfma_type
.
k_base
==
0
,
"KPack is wrong!"
);
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
...
...
@@ -136,21 +147,21 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
constexpr
index_t
KPerBlock
=
ABlockDesc
{}.
GetLength
(
I0
);
static_for
<
0
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
)
{
// read A
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// read B
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
...
@@ -168,11 +179,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerWave
>
{}
,
Number
<
MRepeat
>
{},
I1
));
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
KPack
>
{}
));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerWave
>
{}
,
Number
<
NRepeat
>
{},
I1
));
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
KPack
>
{}
));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
...
...
@@ -181,20 +192,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
FloatA
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
KPerWave
,
MRepeat
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
1
,
Sequence
<
1
,
MRepeat
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// KPack,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
KPerWave
,
NRepeat
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
1
,
Sequence
<
1
,
NRepeat
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// KPack,
1
>
;
AThreadCopy
a_thread_copy_
;
...
...
@@ -208,7 +219,7 @@ template <index_t BlockSize,
class
BBlockDesc
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KP
erWave
>
index_t
KP
ack
>
struct
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
...
...
@@ -219,7 +230,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
MPerWave
,
NPerWave
,
KP
erWave
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
MPerWave
,
NPerWave
,
KP
ack
>
{};
static
constexpr
index_t
WaveSize
=
64
;
...
...
@@ -252,14 +263,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
0
,
m_offset
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
);
return
make_tuple
(
k_offset
,
0
,
m_offset
,
0
);
}
else
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
0
,
m_offset
);
return
make_tuple
(
k_offset
,
0
,
m_offset
,
0
);
}
}
...
...
@@ -274,14 +285,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
0
,
n_offset
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
);
return
make_tuple
(
k_offset
,
0
,
n_offset
,
0
);
}
else
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
0
,
n_offset
);
return
make_tuple
(
k_offset
,
0
,
n_offset
,
0
);
}
}
...
...
@@ -313,8 +324,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
static_assert
(
ABlockDesc
{}.
GetLength
(
I3
)
==
BBlockDesc
{}.
GetLength
(
I3
),
"wrong! KPack dimension not consistent"
);
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
KPack
==
BBlockDesc
{}.
GetLength
(
I3
),
"KPack is wrong!"
);
constexpr
index_t
KPerBlock
=
ABlockDesc
{}.
GetLength
(
I0
);
static_assert
(
KPerBlock
%
xdlops_gemm
.
KPerXdlops
==
0
,
"KPerBlock is wrong!"
);
static_assert
(
KPack
%
xdlops_gemm
.
mfma_type
.
k_base
==
0
,
"KPack is wrong!"
);
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
...
...
@@ -331,34 +353,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
...
...
@@ -375,13 +397,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
static_for
<
KPerWave
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
static_for
<
xdlops_gemm
.
KPerXdlops
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
...
...
@@ -393,10 +415,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
...
...
@@ -408,18 +430,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
make_tuple
(
k
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
make_tuple
(
k
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
...
...
@@ -455,11 +477,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerWave
>
{}
,
Number
<
MRepeat
>
{},
I1
));
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
KPack
>
{}
));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerWave
>
{}
,
Number
<
NRepeat
>
{},
I1
));
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
KPack
>
{}
));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
...
...
@@ -468,20 +490,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
FloatA
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
KPerWave
,
1
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
1
,
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// KPack,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
KPerWave
,
1
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
1
,
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// KPack,
1
>
;
AThreadCopy
a_thread_copy_
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
de9f5bed
...
...
@@ -110,7 +110,7 @@ template <index_t BlockSize,
index_t
KPerBlock
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KP
erWave
,
index_t
KP
ack
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K_M
,
...
...
@@ -276,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBloc
l
, NPerBlock] is in LDS
// b_mtx[KPerBloc
k
, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
...
...
@@ -285,31 +285,35 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
NPerBlock
%
(
NPerWave
*
NRepeat
)
==
0
,
"wrong!"
);
constexpr
auto
a_k_m0_m1_block_desc
=
transform_dynamic_tensor_descriptor
(
static_assert
(
KPerBlock
%
KPack
==
0
,
"KPerBlock is wrong!"
);
constexpr
auto
a_k0_m0_m1_k1_block_desc
=
transform_dynamic_tensor_descriptor
(
a_k_m_block_desc
,
make_tuple
(
make_pass_through_transform
(
Number
<
KPerBlock
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MPerBlock
/
MRepeat
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
KPerBlock
/
KPack
>
{},
Number
<
KPack
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MPerBlock
/
MRepeat
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
>
{}));
constexpr
auto
b_k_n0_n1_block_desc
=
transform_dynamic_tensor_descriptor
(
constexpr
auto
b_k
0
_n0_n1_
k1_
block_desc
=
transform_dynamic_tensor_descriptor
(
b_k_n_block_desc
,
make_tuple
(
make_pass_through_transform
(
Number
<
KPerBlock
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NPerBlock
/
NRepeat
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
KPerBlock
/
KPack
>
{},
Number
<
KPack
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NPerBlock
/
NRepeat
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
>
{}));
const
auto
blockwise_gemm
=
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
<
BlockSize
,
FloatAB
,
FloatAB
,
decltype
(
a_k_m0_m1_block_desc
),
decltype
(
b_k_n0_n1_block_desc
),
decltype
(
a_k
0
_m0_m1_
k1_
block_desc
),
decltype
(
b_k
0
_n0_n1_
k1_
block_desc
),
MPerWave
,
NPerWave
,
KP
erWave
>
{};
KP
ack
>
{};
constexpr
auto
CLayout
=
blockwise_gemm
.
GetCLayout
();
constexpr
index_t
BlkSize
=
CLayout
.
GetBlkSize
();
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
de9f5bed
...
...
@@ -547,7 +547,7 @@ struct xdlops_info
static
constexpr
index_t
GetKPerXdlops
()
{
return
mfma_type
.
k_base
*
(
IsKReduction
()
?
mfma_type
.
num_input_blks
:
1
)
;
return
IsKReduction
()
?
mfma_type
.
num_input_blks
:
1
;
}
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
...
...
@@ -555,7 +555,7 @@ struct xdlops_info
static
constexpr
auto
GetCType
()
{
return
CType_
{};
}
};
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KP
erWave
>
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KP
ack
>
struct
XdlopsGemm
{
template
<
class
base_type_
=
base_type
,
...
...
@@ -801,13 +801,13 @@ struct XdlopsGemm
is_same
<
base_type
,
ushort
>::
value
,
"base base_type must be float, half, ushort!"
);
static_assert
(
KP
erWave
%
KPerXdlops
==
0
,
"KP
erWave
cannot be divided by
KPerXdlops
"
);
static_assert
(
KP
ack
%
mfma_type
.
k_base
==
0
,
"KP
ack
cannot be divided by
k_base
"
);
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
m0
,
n0
))
*
GetNumXdlops
();
static_for
<
0
,
KP
erWave
,
KPerXdlops
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
k
,
m0
,
0
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
k
,
n0
,
0
));
static_for
<
0
,
KP
ack
,
mfma_type
.
k_base
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
0
,
m0
,
0
,
k
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
0
,
n0
,
0
,
k
));
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_a_wave
[
Number
<
a_offset
>
{}],
p_b_wave
[
Number
<
b_offset
>
{}],
p_c_thread
);
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
de9f5bed
...
...
@@ -88,7 +88,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment