Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1e339898
Commit
1e339898
authored
Oct 18, 2024
by
aska-0096
Browse files
temp save
parent
11444e4c
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1251 additions
and
269 deletions
+1251
-269
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16_v3.cpp
example/01_gemm/gemm_xdl_fp16_v3.cpp
+3
-3
include/ck/tensor/static_tensor.hpp
include/ck/tensor/static_tensor.hpp
+26
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
...eration/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+138
-22
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
...ion/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
+6
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+41
-37
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+168
-205
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+827
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+41
-0
No files found.
example/01_gemm/CMakeLists.txt
View file @
1e339898
...
@@ -25,6 +25,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
...
@@ -25,6 +25,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable
(
example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_streamk_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_streamk_v3
)
add_example_executable
(
example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp
)
target_compile_options
(
example_gemm_xdl_fp16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v3
)
add_example_executable
(
example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_v3
)
...
...
example/01_gemm/gemm_xdl_fp16_v3.cpp
View file @
1e339898
...
@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
...
@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmV2Instance
=
using
DeviceGemmV2Instance
=
...
@@ -29,13 +29,13 @@ using DeviceGemmV2Instance =
...
@@ -29,13 +29,13 @@ using DeviceGemmV2Instance =
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
224
,
256
,
224
,
256
,
64
,
8
,
2
,
64
,
8
,
8
,
16
,
16
,
16
,
16
,
7
,
8
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
2
,
0
,
1
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
>
;
// clang-format on
// clang-format on
...
...
include/ck/tensor/static_tensor.hpp
View file @
1e339898
...
@@ -218,6 +218,32 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -218,6 +218,32 @@ struct StaticTensorTupleOfVectorBuffer
}
}
}
}
// Tracing variant of a typed store: prints the compile-time index and the offset
// computed for it, then performs the same data_.SetAsType<X>() store guarded by the
// validity check.
//
// X   - value type to store; must share its scalar type with S.
// Idx - compile-time multi-index with Idx::Size() == ndim_. The trace line prints
//       components 0..3 only, so Idx must have at least four components; any
//       further components are not printed.
// x   - value written at the computed offset when that offset is valid.
//
// Only the thread whose get_thread_local_1d_id() is 0 emits the trace line.
template <typename X,
          typename Idx,
          typename enable_if<has_same_scalar_type<S, X>::value &&
                                 is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                             bool>::type = false>
__host__ __device__ constexpr void SetAsType_Print(Idx, X x)
{
    // The printf below reads Idx components 0..3 unconditionally; fail with a clear
    // message instead of an opaque out-of-range At() error for lower-rank tensors.
    static_assert(Idx::Size() >= 4,
                  "SetAsType_Print prints index components 0..3; Idx needs at least 4 "
                  "components");

    constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
    constexpr index_t offset = coord.GetOffset();

    if(get_thread_local_1d_id() == 0)
    {
        printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n",
               get_thread_local_1d_id(),
               Idx{}.At(Number<0>{}).value,
               Idx{}.At(Number<1>{}).value,
               Idx{}.At(Number<2>{}).value,
               Idx{}.At(Number<3>{}).value,
               offset);
    }

    constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

    // The store is skipped entirely when the coordinate has no valid offset.
    if constexpr(is_valid)
    {
        data_.template SetAsType<X>(Number<offset>{}, x);
    }
}
// Get read access to V. No is_valid check
// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
// Idx is for S, not V. Idx should be aligned with V
template
<
typename
Idx
>
template
<
typename
Idx
>
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
View file @
1e339898
...
@@ -30,6 +30,8 @@ template <index_t BlockSize,
...
@@ -30,6 +30,8 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
KPack
,
bool
TransposeA
=
false
,
bool
TransposeB
=
false
,
bool
TransposeC
=
false
>
bool
TransposeC
=
false
>
struct
BlockwiseGemmXdlops_pipeline_base
struct
BlockwiseGemmXdlops_pipeline_base
{
{
...
@@ -152,6 +154,38 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -152,6 +154,38 @@ struct BlockwiseGemmXdlops_pipeline_base
return
make_tuple
(
c_thread_m
,
c_thread_n
);
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
}
// Contiguous output tile.
// Returns make_tuple(c_thread_m, c_thread_n): the calling thread's origin inside the C
// tile for repeat (m0, n0) and xdlops block (xdlops_i, blk_i).
//   - M is unmerged as (MRepeat, MWaves, MPerXDL) and indexed with
//     (m0, wave m-id, thread-block m-origin).
//   - N is unmerged as (NWaves, NPerXDL, NRepeat) and indexed with
//     (wave n-id, thread-block n-origin, n0), i.e. the repeat index is the last
//     component here rather than the first.
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto CalculateCThreadOriginDataIndexContiguous(Number<m0>,
                                                                 Number<n0>,
                                                                 Number<xdlops_i>,
                                                                 Number<blk_i>)
{
    const auto wave_id = GetWaveIdx();

    const auto wave_m = wave_id[I0];
    const auto wave_n = wave_id[I1];

    const auto thread_blk_origin = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

    constexpr auto m_from_mrepeat_mwave_mperxdl = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1, 2>{}));

    constexpr auto n_from_nwave_nperxdl_nrepeat = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1, 2>{}));

    const index_t c_thread_m = m_from_mrepeat_mwave_mperxdl.CalculateBottomIndex(
        make_tuple(m0, wave_m, thread_blk_origin[I0]))[I0];

    const index_t c_thread_n = n_from_nwave_nperxdl_nrepeat.CalculateBottomIndex(
        make_tuple(wave_n, thread_blk_origin[I1], n0))[I0];

    return make_tuple(c_thread_m, c_thread_n);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
...
@@ -212,6 +246,21 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -212,6 +246,21 @@ struct BlockwiseGemmXdlops_pipeline_base
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
}
// Contiguous output tile.
// Builds the packed 10-d per-thread C descriptor. With (M0, M1, M2, N) taken from
// xdlops_gemm.GetCM0M1M2NThreadBlkLengths(), the lengths are, in order:
//   (1, 1, MRepeat, 1, 1, M0, M1, N, NRepeat, M2)
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
{
    constexpr auto thread_blk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    constexpr auto len_m0 = thread_blk_lens[I0];
    constexpr auto len_m1 = thread_blk_lens[I1];
    constexpr auto len_m2 = thread_blk_lens[I2];
    constexpr auto len_n  = thread_blk_lens[I3];

    // NRepeat sits in the second-to-last slot and M2 in the last one.
    constexpr auto thread_lengths = make_tuple(I1,
                                               I1,
                                               Number<MRepeat>{},
                                               I1,
                                               I1,
                                               len_m0,
                                               len_m1,
                                               len_n,
                                               Number<NRepeat>{},
                                               len_m2);

    return make_naive_tensor_descriptor_packed(thread_lengths);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
...
@@ -253,6 +302,23 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -253,6 +302,23 @@ struct BlockwiseGemmXdlops_pipeline_base
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
}
// Builds the block-level C descriptor for the contiguous output layout: a packed
// descriptor with lengths (1, 1, MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL)
// is created and handed to
// xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4, whose result is
// returned unchanged.
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
{
    constexpr auto block_lengths = make_tuple(I1,
                                              I1,
                                              Number<MRepeat>{},
                                              Number<NRepeat>{},
                                              Number<MWaves>{},
                                              Number<NWaves>{},
                                              Number<MPerXDL>{},
                                              Number<NPerXDL>{});

    constexpr auto packed_block_desc = make_naive_tensor_descriptor_packed(block_lengths);

    return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
        packed_block_desc);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
...
@@ -327,28 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -327,28 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_base
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
template
<
bool
Transpose
>
ComputeDataType
,
struct
AThreadCopySelector
;
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
template
<
>
Sequence
<
1
,
1
,
1
,
KPack
>
,
struct
AThreadCopySelector
<
false
>
Sequence
<
0
,
1
,
2
,
3
>
,
{
3
,
using
type
=
ThreadwiseTensorSliceTransfer_v5
<
ADataType
,
A_K1
,
ComputeDataType
,
A_K1
>
;
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
Sequence
<
MRepeat
,
1
,
1
,
KPack
>
,
ComputeDataType
,
Sequence
<
0
,
1
,
2
,
3
>
,
decltype
(
b_block_desc_n0_n1_n2_k
),
Sequence
<
0
,
1
,
2
,
3
>
,
decltype
(
b_thread_desc_
),
3
,
Sequence
<
1
,
1
,
1
,
KPack
>
,
3
,
Sequence
<
0
,
1
,
2
,
3
>
,
A_K1
,
3
,
A_K1
>
;
B_K1
,
};
B_K1
>
;
template
<
>
AThreadCopy
a_thread_copy_
;
struct
AThreadCopySelector
<
true
>
BThreadCopy
b_thread_copy_
;
{
using
type
=
ThreadwiseTensorSliceTransfer_v5
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
MRepeat
,
1
,
1
,
KPack
>
,
Sequence
<
3
,
1
,
2
,
0
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
0
,
3
,
MRepeat
,
A_K1
>
;
};
// Selects the thread-copy type used for the B operand, keyed on a transpose flag.
// Both specializations instantiate ThreadwiseTensorSliceTransfer_v5 with the same
// data types, descriptors and slice lengths Sequence<NRepeat, 1, 1, KPack>; they
// differ only in three positional arguments (marked below).
// NOTE(review): the meaning of each positional argument is defined by
// ThreadwiseTensorSliceTransfer_v5, which is not visible here - confirm against its
// declaration before changing any of them.
template <bool Transpose>
struct BThreadCopySelector;

// Transpose == false
template <>
struct BThreadCopySelector<false>
{
    using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                  ComputeDataType,
                                                  decltype(b_block_desc_n0_n1_n2_k),
                                                  decltype(b_thread_desc_),
                                                  Sequence<NRepeat, 1, 1, KPack>,
                                                  Sequence<0, 1, 2, 3>, // differs
                                                  Sequence<0, 1, 2, 3>,
                                                  3, // differs
                                                  3,
                                                  B_K1, // differs
                                                  B_K1>;
};

// Transpose == true
template <>
struct BThreadCopySelector<true>
{
    using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
                                                  ComputeDataType,
                                                  decltype(b_block_desc_n0_n1_n2_k),
                                                  decltype(b_thread_desc_),
                                                  Sequence<NRepeat, 1, 1, KPack>,
                                                  Sequence<3, 1, 2, 0>, // differs
                                                  Sequence<0, 1, 2, 3>,
                                                  0, // differs
                                                  3,
                                                  NRepeat, // differs
                                                  B_K1>;
};
typename
AThreadCopySelector
<
TransposeA
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
TransposeB
>::
type
b_thread_copy_
;
};
};
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
View file @
1e339898
...
@@ -40,7 +40,9 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
...
@@ -40,7 +40,9 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
>
index_t
KPack
,
bool
TransposeA
,
bool
TransposeB
>
constexpr
auto
BlockGemmPipeline_Selector
()
constexpr
auto
BlockGemmPipeline_Selector
()
{
{
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
...
@@ -110,7 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
...
@@ -110,7 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
TransposeA
,
TransposeB
>
{};
}
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
{
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
1e339898
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
index_t
KPack
,
bool
TransposeA
,
bool
TransposeB
>
struct
BlockwiseGemmXdlops_pipeline_v3
struct
BlockwiseGemmXdlops_pipeline_v3
{
{
};
};
...
@@ -55,7 +57,9 @@ template <index_t BlockSize,
...
@@ -55,7 +57,9 @@ template <index_t BlockSize,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
index_t
KPack
,
bool
TransposeA
,
bool
TransposeB
// ,bool TransposeC //disable transposec right now...
// ,bool TransposeC //disable transposec right now...
>
>
struct
BlockwiseGemmXdlops_pipeline_v3
<
BlockGemmPipelineScheduler
::
Intrawave
,
struct
BlockwiseGemmXdlops_pipeline_v3
<
BlockGemmPipelineScheduler
::
Intrawave
,
...
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
KPack
,
TransposeA
,
TransposeB
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
KPack
,
TransposeA
,
TransposeB
>
{
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
...
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL
,
NPerXDL
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
;
KPack
,
TransposeA
,
TransposeB
>
;
using
Base
::
I0
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
KRepeat
;
...
@@ -322,22 +332,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -322,22 +332,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
// Local prefetch 1
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
// make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
// a_block_buf,
a_block_buf
,
// a_thread_desc_,
a_thread_desc_
,
// make_tuple(I0, I0, k0, I0),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
// a_thread_buf);
a_thread_buf
);
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
I0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_thread_desc_
,
b_block_buf
,
make_tuple
(
I0
,
I0
,
k0
,
I0
),
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -392,22 +399,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -392,22 +399,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
// make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
// a_block_buf,
a_block_buf
,
// a_thread_desc_,
a_thread_desc_
,
// make_tuple(I0, I0, k0, I0),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
// a_thread_buf);
a_thread_buf
);
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
I0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_thread_desc_
,
b_block_buf
,
make_tuple
(
I0
,
I0
,
k0
,
I0
),
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
});
HotLoopScheduler
();
HotLoopScheduler
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
1e339898
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
1e339898
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
1e339898
...
@@ -950,6 +950,47 @@ struct XdlopsGemm
...
@@ -950,6 +950,47 @@ struct XdlopsGemm
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
}
}
// Transforms an 8-d C descriptor (MBlock, NBlock, M0, N0, M1, N1, M2, N2) into a 10-d
// one. Input dimensions 0..5 and 7 pass through; input dimension 6 is unmerged into
// (num_groups_per_blk, num_input_blks, group_size) of mfma_instr.
//
// Input dimension -> output dimension(s):
//   0 -> 0, 1 -> 1, 2 -> 2, 3 -> 8, 4 -> 3, 5 -> 4, 6 -> (5, 6, 9), 7 -> 7
template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
    const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
{
    const auto len_mblock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
    const auto len_nblock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
    const auto len_m0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
    const auto len_n0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
    const auto len_m1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
    const auto len_n1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);

    // One transform per input dimension, in input order.
    const auto transforms = make_tuple(
        make_pass_through_transform(len_mblock),
        make_pass_through_transform(len_nblock),
        make_pass_through_transform(len_m0),
        make_pass_through_transform(len_n0),
        make_pass_through_transform(len_m1),
        make_pass_through_transform(len_n1),
        make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
                                          Number<mfma_instr.num_input_blks>{},
                                          Number<mfma_instr.group_size>{})),
        make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}));

    // Input dimension consumed by each transform.
    const auto input_dims = make_tuple(Sequence<0>{},
                                       Sequence<1>{},
                                       Sequence<2>{},
                                       Sequence<3>{},
                                       Sequence<4>{},
                                       Sequence<5>{},
                                       Sequence<6>{},
                                       Sequence<7>{});

    // Output dimension(s) produced by each transform.
    const auto output_dims = make_tuple(Sequence<0>{},
                                        Sequence<1>{},
                                        Sequence<2>{},
                                        Sequence<8>{},
                                        Sequence<3>{},
                                        Sequence<4>{},
                                        Sequence<5, 6, 9>{},
                                        Sequence<7>{});

    return transform_tensor_descriptor(
        c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2, transforms, input_dims, output_dims);
}
// transposed XDL output supporting C' = B' * A'
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
// M2_N2 -> M2_N2_N3_N4
template
<
typename
CDesc_M0_N0_M1_N1_M2_N2
>
template
<
typename
CDesc_M0_N0_M1_N1_M2_N2
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment