gaoqiong / composable_kernel_ROCM · Commits

Commit 72a488cc, authored Oct 23, 2024 by aska-0096
Parent: ea41fc2f

    All layout sanity pass

Showing 6 changed files with 743 additions and 609 deletions (+743 −609)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp  +67  −10
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp    +63  −70
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp    +148 −93
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp         +377 −419
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                           +34  −0
include/ck/utility/transpose_vectors.hpp                                       +54  −17
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
@@ -154,12 +154,12 @@ struct BlockwiseGemmXdlops_pipeline_base
        return make_tuple(c_thread_m, c_thread_n);
    }

    // Contiguous output tile
    // N Contiguous output tile
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
        CalculateCThreadOriginDataIndexContiguous(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    __device__ static auto
        CalculateCThreadOriginDataIndexNContiguous(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
...
@@ -186,6 +186,38 @@ struct BlockwiseGemmXdlops_pipeline_base
        return make_tuple(c_thread_m, c_thread_n);
    }

    // MNContiguous output tile
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
        CalculateCThreadOriginDataIndexMNContiguous(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(waveId_m, blk_idx[I0], m0))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(waveId_n, blk_idx[I1], n0))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
        CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
...
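For reference, the bottom-index computation performed by the two single-stage adaptors in CalculateCThreadOriginDataIndexMNContiguous can be written out by hand. The sketch below is not part of this commit; it assumes CK's make_unmerge_transform packs its lengths with the last dimension varying fastest, and the tile sizes are made-up example values.

#include <cstdio>

// Hand-rolled equivalent of mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex()
// under the stated packing assumption (last unmerge dimension is fastest-varying).
constexpr int unmerge_bottom_index(int wave, int blk, int repeat, int per_xdl, int n_repeat)
{
    return wave * per_xdl * n_repeat + blk * n_repeat + repeat;
}

int main()
{
    constexpr int MWaves = 2, MPerXDL = 32, MRepeat = 4; // example tile shape, not from the commit

    // waveId_m = 1, blk_idx[I0] = 5, m0 = 3
    const int c_thread_m = unmerge_bottom_index(1, 5, 3, MPerXDL, MRepeat);

    std::printf("c_thread_m = %d (MWaves = %d)\n", c_thread_m, MWaves);
    return 0;
}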
@@ -246,19 +278,30 @@ struct BlockwiseGemmXdlops_pipeline_base
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    // Contiguous output tile
    // N-Contiguous output tile
    __host__ __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, I1, I1, Number<NRepeat>{}, M2));
    }

    // MN-Contiguous output tile
    __host__ __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, I1, I1, I1, M0, I1, I1, Number<MRepeat>{}, Number<NRepeat>{}, M2));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
...
@@ -315,6 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_base
        return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    // TransposeA
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
...
@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
                                         decltype(b_block_desc_n0_n1_n2_k),
                                         decltype(b_thread_desc_),
                                         Sequence<NRepeat, 1, 1, KPack>,
                                         Sequence<0, 1, 2, 3>,
                                         Sequence<2, 0, 1, 3>,
                                         Sequence<0, 1, 2, 3>,
                                         3,
                                         3,
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
          index_t KPacks,
          bool TransposeA,
          bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v4
{
};
...
@@ -55,7 +57,9 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          index_t KPack,
          bool TransposeA,
          bool TransposeB
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
                                       KPack,
                                       TransposeA,
                                       TransposeB>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
...
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
                                        KPack,
                                        TransposeA,
                                        TransposeB>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
...
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
                                                   KPack,
                                                   TransposeA,
                                                   TransposeB>;

    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
...
@@ -298,22 +308,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(I0),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(I0));
                });
            });
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
                               a_block_buf.At(I0),
                               a_thread_desc_,
                               make_tuple(I0, I0, k, I0),
                               a_thread_bufs(I0));
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
                               b_block_buf.At(I0),
                               b_thread_desc_,
                               make_tuple(I0, I0, k, I0),
                               b_thread_bufs(I0));
        });

        // Global prefetch 3
...
@@ -349,23 +355,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf.At(lds_read_buf),
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_bufs(lds_read_reg_buf));
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(I0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(I0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });

                a_blockwise_copy.RunWrite(
...
@@ -430,22 +431,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf.At(lds_read_buf),
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_bufs(lds_read_reg_buf));
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(I0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(I0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
...
@@ -489,22 +486,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf.At(lds_read_buf),
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_bufs(lds_read_reg_buf));
                    });
                });
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(lds_read_buf),
                                   a_thread_desc_,
                                   make_tuple(I0, I0, k, I0),
                                   a_thread_bufs(lds_read_reg_buf));
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
                                   b_block_buf.At(lds_read_buf),
                                   b_thread_desc_,
                                   make_tuple(I0, I0, k, I0),
                                   b_thread_bufs(lds_read_reg_buf));
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
          index_t KPacks,
          bool TransposeA,
          bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v5
{
};
...
@@ -55,7 +57,9 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          index_t KPack,
          bool TransposeA,
          bool TransposeB
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
                                       KPack,
                                       TransposeA,
                                       TransposeB>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
...
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
                                        KPack,
                                        TransposeA,
                                        TransposeB>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
...
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
                                                   KPack,
                                                   TransposeA,
                                                   TransposeB>;

    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
...
@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
        // Local prefetch 1
        block_sync_lds();
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(m0, I0, I0, I0),
                               a_thread_buf);
        });
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               make_tuple(n0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(n0, I0, I0, I0),
                               b_thread_buf);
        });
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                           make_tuple(I0, I0, I0, I0),
                           a_block_buf,
                           a_thread_desc_,
                           make_tuple(I0, I0, I0, I0),
                           a_thread_buf);
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                           make_tuple(I0, I0, I0, I0),
                           b_block_buf,
                           b_thread_desc_,
                           make_tuple(I0, I0, I0, I0),
                           b_thread_buf);

        // main body
        if constexpr(HasMainLoop)
...
@@ -449,25 +456,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                           b_thread_vec.template AsType<mfma_input_type>(),
                                           c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });

                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, I0, I0),
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(n0, I0, I0, I0),
                                       b_thread_buf);
                });
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);
            });

            HotLoopScheduler();
...
@@ -517,24 +522,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                           b_thread_vec.template AsType<mfma_input_type>(),
                                           c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });

                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, I0, I0),
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(n0, I0, I0, I0),
                                       b_thread_buf);
                });
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);
            });

            HotLoopScheduler();
...
@@ -567,25 +571,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                       b_thread_vec.template AsType<mfma_input_type>(),
                                       c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(m0, I0, I0, I0),
                                   a_thread_buf);
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(n0, I0, I0, I0),
                                   b_thread_buf);
            });
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               b_thread_buf);
        });

        static_for<0, MRepeat, 1>{}([&](auto m0) {
...
@@ -636,28 +638,81 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    template <bool Transpose>
    struct AThreadCopySelector;

    template <>
    struct AThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
                                                      ComputeDataType,
                                                      decltype(a_block_desc_m0_m1_m2_k),
                                                      decltype(a_thread_desc_),
                                                      Sequence<MRepeat, 1, 1, KPack>,
                                                      Sequence<0, 1, 2, 3>,
                                                      Sequence<0, 1, 2, 3>,
                                                      3,
                                                      3,
                                                      A_K1,
                                                      A_K1>;
    };

    template <>
    struct AThreadCopySelector<true>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
                                                      ComputeDataType,
                                                      decltype(a_block_desc_m0_m1_m2_k),
                                                      decltype(a_thread_desc_),
                                                      Sequence<MRepeat, 1, 1, KPack>,
                                                      Sequence<3, 1, 2, 0>,
                                                      Sequence<0, 1, 2, 3>,
                                                      0,
                                                      3,
                                                      MRepeat,
                                                      A_K1>;
    };

    template <bool Transpose>
    struct BThreadCopySelector;

    template <>
    struct BThreadCopySelector<false>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                      ComputeDataType,
                                                      decltype(b_block_desc_n0_n1_n2_k),
                                                      decltype(b_thread_desc_),
                                                      Sequence<NRepeat, 1, 1, KPack>,
                                                      Sequence<2, 0, 1, 3>,
                                                      Sequence<0, 1, 2, 3>,
                                                      3,
                                                      3,
                                                      B_K1,
                                                      B_K1>;
    };

    template <>
    struct BThreadCopySelector<true>
    {
        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                      ComputeDataType,
                                                      decltype(b_block_desc_n0_n1_n2_k),
                                                      decltype(b_thread_desc_),
                                                      Sequence<NRepeat, 1, 1, KPack>,
                                                      Sequence<3, 1, 2, 0>,
                                                      Sequence<0, 1, 2, 3>,
                                                      0,
                                                      3,
                                                      NRepeat,
                                                      B_K1>;
    };

    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
...
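The AThreadCopySelector / BThreadCopySelector specializations added above pick one of two ThreadwiseTensorSliceTransfer_v5 parameterizations at compile time from a single bool. A minimal, library-independent sketch of that selection pattern follows; the CopyRowMajor / CopyTransposed types and the PipelineSketch wrapper are hypothetical stand-ins, not CK classes.

#include <cstdio>

// Stand-ins for the two differently parameterized copy engines.
struct CopyRowMajor   { static void run() { std::puts("row-major thread copy"); } };
struct CopyTransposed { static void run() { std::puts("transposed thread copy"); } };

// Primary template is only declared; the two bool cases are specialized,
// mirroring AThreadCopySelector<false> / AThreadCopySelector<true>.
template <bool Transpose>
struct ThreadCopySelector;

template <>
struct ThreadCopySelector<false> { using type = CopyRowMajor; };

template <>
struct ThreadCopySelector<true> { using type = CopyTransposed; };

template <bool TransposeA>
struct PipelineSketch
{
    // The member's concrete type is fixed at instantiation time, so the
    // transpose decision costs nothing at run time.
    using ACopy = typename ThreadCopySelector<TransposeA>::type;

    void prefetch() const { ACopy::run(); }
};

int main()
{
    PipelineSketch<false>{}.prefetch(); // selects CopyRowMajor
    PipelineSketch<true>{}.prefetch();  // selects CopyTransposed
    return 0;
}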
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
@@ -235,25 +235,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
                           Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<1, 2, 0>{}));
#if 0
constexpr auto mma_transformed =
transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
mma_transformed,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_pass_through_transform(Number<MNWaves>{}),
make_pass_through_transform(Number<MNPerXdl>{}),
make_pass_through_transform(Number<MNXdlPerWave>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<0>{}));
#endif
    }

    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
...
@@ -448,20 +429,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto b_mma_desc = [&]() {
            if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
            }
            else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(
                    BBlockDesc_BK0_N_BK1{});
            }
        }();

        return b_mma_desc;
        return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }

    __host__ __device__ static auto
...
@@ -484,45 +453,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
                       make_right_pad_transform(N, NPad - N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
    }

    struct Problem
...
@@ -723,92 +653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }
        else // ColumnMajor A
        {
#if 0
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
#endif
            static_assert(ABlockTransferSrcScalarPerVector % MXdlPerWave == 0);
            return make_naive_tensor_descriptor(
...
@@ -867,89 +711,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }
        else // RowMajor B
        {
#if 0
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
#endif
            static_assert(BBlockTransferSrcScalarPerVector % NXdlPerWave == 0);
            return make_naive_tensor_descriptor(
...
@@ -958,19 +719,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }
    }

    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
        if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
        {
            return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4();
        }
        else
        {
            return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
        }
    }

    using BlockwiseGemmPipe =
...
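The new GetCThreadDescriptor_MBlock_NBlock() returns a differently shaped descriptor depending on ALayout. This works because, in a function returning auto, only the taken branch of an if constexpr is instantiated, so the two branches may return different types. A small sketch of the idiom with plain structs in place of CK descriptors (the names here are illustrative only, not part of the commit):

#include <cstdio>
#include <type_traits>

struct ColumnMajor {};
struct RowMajor {};

struct DescMNContiguous { static constexpr int rank = 10; };
struct DescNContiguous  { static constexpr int rank = 10; };

// Mirrors GetCThreadDescriptor_MBlock_NBlock(): the discarded branch is never
// instantiated, so each layout sees exactly one concrete descriptor type.
template <typename ALayout>
constexpr auto get_c_thread_descriptor()
{
    if constexpr(std::is_same_v<ALayout, ColumnMajor>)
    {
        return DescMNContiguous{};
    }
    else
    {
        return DescNContiguous{};
    }
}

int main()
{
    auto d0 = get_c_thread_descriptor<ColumnMajor>();
    auto d1 = get_c_thread_descriptor<RowMajor>();
    static_assert(std::is_same_v<decltype(d0), DescMNContiguous>);
    static_assert(std::is_same_v<decltype(d1), DescNContiguous>);
    std::printf("ranks: %d %d\n", d0.rank, d1.rank);
    return 0;
}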
@@ -1413,88 +1171,188 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                               num_k_block_main_loop);

        // Epilogue
        constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();

        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});

        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(make_pass_through_transform(problem.MBlock),
                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                       make_pass_through_transform(problem.NBlock),
                       make_unmerge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
            make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // Typecast -> Permute -> Coalesced vector store
        auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
            AccDataType, CDataType,
            decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
            decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
            CElementwiseOperation,
            Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
            9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
            c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
            make_multi_index(block_m_id, block_n_id,
                             m_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3], n_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
            c_element_op};

        c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
        constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();

        auto c_block_trait = [&]() {
            if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
                    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();

                constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
                constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
                constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
                constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
                constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
                constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
                constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
                constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});

                const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 = transform_tensor_descriptor(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(make_pass_through_transform(problem.MBlock),
                               make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
                               make_pass_through_transform(problem.NBlock),
                               make_unmerge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<2, 4, 5, 9, 7>{}, Sequence<1>{}, Sequence<3, 6, 8>{}));

                const auto c_thread_mtx_on_block =
                    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // Typecast -> Permute -> Coalesced vector store
                auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
                    AccDataType, CDataType,
                    decltype(c_thread_desc_mblock_nblock),
                    decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
                    CElementwiseOperation,
                    Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
                    9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
                    c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
                    make_multi_index(block_m_id, block_n_id,
                                     m_thread_data_on_block_idx[I0], n_thread_data_on_block_idx[I0],
                                     m_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
                    c_element_op};

                return make_tuple(c_thread_copy_vgpr_to_global,
                                  c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
            }
            else
            {
                constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
                    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();

                constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
                constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
                constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
                constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
                constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
                constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
                constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
                constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});

                const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(make_pass_through_transform(problem.MBlock),
                               make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                               make_pass_through_transform(problem.NBlock),
                               make_unmerge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));

                const auto c_thread_mtx_on_block =
                    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // Typecast -> Permute -> Coalesced vector store
                auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
                    AccDataType, CDataType,
                    decltype(c_thread_desc_mblock_nblock),
                    decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                    CElementwiseOperation,
                    Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
                    9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
                    c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                    make_multi_index(block_m_id, block_n_id,
                                     m_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3], n_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
                    c_element_op};

                return make_tuple(c_thread_copy_vgpr_to_global,
                                  c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
            }
        }();

        auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
        auto c_grid_desc_mblock_nblock    = c_block_trait.At(Number<1>{});

        c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
                                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                         c_thread_buf,
                                         c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                         c_grid_desc_mblock_nblock,
                                         c_grid_buf);
    }
...
@@ -1607,11 +1465,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                            decltype(a_grid_desc_ak0_m_ak1),
                                            decltype(a_block_desc_ak0_m_ak1),
                                            ABlockTransferSrcAccessOrder,
                                            Sequence<0, 1, 2>,
                                            ABlockTransferSrcAccessOrder,
                                            ABlockTransferSrcVectorDim,
                                            ABlockTransferSrcVectorDim,
                                            2,
                                            ABlockTransferSrcScalarPerVector,
                                            ABlockTransferDstScalarPerVector_AK1,
                                            ABlockTransferSrcScalarPerVector,
                                            1,
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
...
@@ -1638,11 +1496,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                            decltype(b_grid_desc_bk0_n_bk1),
                                            decltype(b_block_desc_bk0_n_bk1),
                                            BBlockTransferSrcAccessOrder,
                                            Sequence<0, 1, 2>,
                                            BBlockTransferSrcAccessOrder,
                                            BBlockTransferSrcVectorDim,
                                            BBlockTransferSrcVectorDim,
                                            2,
                                            BBlockTransferSrcScalarPerVector,
                                            BBlockTransferDstScalarPerVector_BK1,
                                            BBlockTransferSrcScalarPerVector,
                                            1,
                                            1,
                                            BThreadTransferSrcResetCoordinateAfterRun,
...
@@ -1706,88 +1564,188 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                               num_k_block_main_loop);

        // Epilogue
        constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();

        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});

        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(make_pass_through_transform(problem.MBlock),
                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                       make_pass_through_transform(problem.NBlock),
                       make_unmerge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
            make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // Typecast -> Permute -> Coalesced vector store
        auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
            AccDataType, CDataType,
            decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
            decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
            CElementwiseOperation,
            Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
            9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
            c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
            make_multi_index(block_m_id, block_n_id,
                             m_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I2],
                             m_thread_data_on_block_idx[I3], n_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
            c_element_op};

        c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
        constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();

        auto c_block_trait = [&]() {
            if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
                    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();

                constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
                constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
                constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
                constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
                constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
                constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
                constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
                constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});

                const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 = transform_tensor_descriptor(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(make_pass_through_transform(problem.MBlock),
                               make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
                               make_pass_through_transform(problem.NBlock),
                               make_unmerge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<2, 4, 5, 9, 7>{}, Sequence<1>{}, Sequence<3, 6, 8>{}));

                const auto c_thread_mtx_on_block =
                    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // Typecast -> Permute -> Coalesced vector store
                auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
                    AccDataType, CDataType,
                    decltype(c_thread_desc_mblock_nblock),
                    decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
                    CElementwiseOperation,
                    Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
                    9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
                    c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
                    make_multi_index(block_m_id, block_n_id,
                                     m_thread_data_on_block_idx[I0], n_thread_data_on_block_idx[I0],
                                     m_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
                    c_element_op};

                return make_tuple(c_thread_copy_vgpr_to_global,
                                  c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
            }
            else
            {
                constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
                    blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();

                constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
                constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
                constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
                constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
                constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
                constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
                constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
                constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});

                const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(make_pass_through_transform(problem.MBlock),
                               make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                               make_pass_through_transform(problem.NBlock),
                               make_unmerge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));

                const auto c_thread_mtx_on_block =
                    blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // Typecast -> Permute -> Coalesced vector store
                auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
                    AccDataType, CDataType,
                    decltype(c_thread_desc_mblock_nblock),
                    decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                    CElementwiseOperation,
                    Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
                    9, 8, M4, N2, InMemoryDataOperationEnum::Set, false>{
                    c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                    make_multi_index(block_m_id, block_n_id,
                                     m_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I0], m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3], n_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I4]),
                    c_element_op};

                return make_tuple(c_thread_copy_vgpr_to_global,
                                  c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
            }
        }();

        auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
        auto c_grid_desc_mblock_nblock    = c_block_trait.At(Number<1>{});

        c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
                                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                         c_thread_buf,
                                         c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                         c_grid_desc_mblock_nblock,
                                         c_grid_buf);
    }
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -983,6 +983,40 @@ struct XdlopsGemm
                       Sequence<5>{}));
    }

    // TransposeA
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto
    MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
    {
        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);

        return transform_tensor_descriptor(
            c_desc_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
                                                         Number<mfma_instr.num_input_blks>{},
                                                         Number<mfma_instr.group_size>{})),
                       make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<7>{},
                       Sequence<6>{},
                       Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2, 3, 5>{},
                       Sequence<4>{}));
    }

    // transposed XDL output supporting C' = B' * A'
    // M2_N2 -> M2_N2_N3_N4
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
...
include/ck/utility/transpose_vectors.hpp
@@ -18,22 +18,6 @@ struct transpose_vectors;
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<half_t, 2> vy0, vy1;
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
    constexpr int32_t m0 = 0x05040100;
    constexpr int32_t m1 = 0x07060302;
...
@@ -43,7 +27,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
    // index is reversed because of little endianness (least significant bits first)
    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}

template <index_t NX, index_t NY>
...
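The masks 0x05040100 and 0x07060302 used by transpose_fp16_2x2 select bytes out of the 8-byte pool formed by the two source registers, byte 0 being the least significant byte of the second operand. Below is a host-side sketch, not part of this commit, that emulates that byte selection in portable C++ so the 2x2 lane transpose can be checked off-GPU; perm_b32 is a simplified stand-in for __builtin_amdgcn_perm and only models selector byte values 0-7.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Portable stand-in for v_perm_b32: build the result from the 8-byte pool
// {b, a}, where bytes 0..3 come from 'a' and bytes 4..7 come from 'b'.
// Each selector byte picks one pool byte (low selector byte -> low result byte).
// Selector values 8..15 (sign/constant replication on real hardware) are not modeled.
static uint32_t perm_b32(uint32_t b, uint32_t a, uint32_t selector)
{
    uint8_t pool[8];
    std::memcpy(pool + 0, &a, 4);
    std::memcpy(pool + 4, &b, 4);

    uint32_t result = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t sel = (selector >> (8 * i)) & 0xFFu;
        result |= static_cast<uint32_t>(pool[sel & 0x7u]) << (8 * i);
    }
    return result;
}

int main()
{
    // Two packed pairs of 16-bit lanes: x0 = {a0, a1}, x1 = {b0, b1}
    // (low half-word first, matching little-endian storage).
    const uint32_t x0 = 0xAAA1AAA0u; // high:low = a1:a0
    const uint32_t x1 = 0xBBB1BBB0u; // high:low = b1:b0

    // Same masks and argument order as transpose_fp16_2x2 / transpose_bf16_2x2.
    const uint32_t y0 = perm_b32(x1, x0, 0x05040100u); // -> {a0 low, b0 high}
    const uint32_t y1 = perm_b32(x1, x0, 0x07060302u); // -> {a1 low, b1 high}

    std::printf("y0 = 0x%08X, y1 = 0x%08X\n",
                static_cast<unsigned>(y0), static_cast<unsigned>(y1));
    return 0;
}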
@@ -83,6 +66,60 @@ struct transpose_vectors<half_t, NX, NY>
    }
};

// transpose bf16 2x2
__device__ void transpose_bf16_2x2(const bhalf2_t& x0, const bhalf2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
{
    constexpr int32_t m0 = 0x05040100;
    constexpr int32_t m1 = 0x07060302;

    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
    //                   -- -- -- --      -- -- -- --       -  -  -  -
    //             index  7  6  5  4       3  2  1  0      33 77 44 88
    // index is reversed because of little endianness (least significant bits first)
    y0 = bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
    y1 = bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
}

template <index_t NX, index_t NY>
struct transpose_vectors<bhalf_t, NX, NY>
{
    // we got [NY * NX] amount of S data to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = bhalf_t;
    using VX = vector_type<bhalf_t, s_per_x>;
    using VY = vector_type<bhalf_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};

        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

        // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // reference to 2 bhalf2_t data from vx_tuple
                const auto& x_s2_0 = vx_tuple[ix].template AsType<bhalf2_t>()[iy / I2];
                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<bhalf2_t>()[iy / I2];

                // reference to 2 bhalf2_t data from vy_tuple
                auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);

                // transpose
                transpose_bf16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
            });
        });
    }
};

// transpose int8 4x4
__device__ void transpose_int8_4x4(const int8x4_t& x0, const int8x4_t& x1,
...