Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
da59d3b2
Commit
da59d3b2
authored
Dec 12, 2024
by
coderfeli
Browse files
remove useless comments and changes
parent
6fd51c43
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
145 additions
and
351 deletions
+145
-351
example/01_gemm/gemm_xdl_fp16_v2.cpp
example/01_gemm/gemm_xdl_fp16_v2.cpp
+4
-6
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+30
-14
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+4
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
+0
-11
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+28
-28
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+0
-5
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+0
-11
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
+0
-39
include/ck_tile/core/tensor/tensor_coordinate.hpp
include/ck_tile/core/tensor/tensor_coordinate.hpp
+0
-10
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+2
-8
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+76
-126
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+1
-92
No files found.
example/01_gemm/gemm_xdl_fp16_v2.cpp
View file @
da59d3b2
...
@@ -15,7 +15,7 @@ using F16 = ck::half_t;
...
@@ -15,7 +15,7 @@ using F16 = ck::half_t;
using
F32
=
float
;
using
F32
=
float
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
...
@@ -32,17 +32,15 @@ using DeviceGemmInstance =
...
@@ -32,17 +32,15 @@ using DeviceGemmInstance =
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
2
,
256
,
256
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
8
,
4
,
32
,
32
,
32
,
32
,
4
,
4
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
2
,
8
,
8
,
0
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
LoopScheduler
::
Default
,
ck
::
PipelineVersion
::
v1
>
;
ck
::
LoopScheduler
::
Default
,
ck
::
PipelineVersion
::
v1
>
;
//./bin/example_gemm_xdl_fp16_v2 0 0 1 5120 5120 8320 8320 8320 5120
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
da59d3b2
...
@@ -22,12 +22,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -22,12 +22,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
...
@@ -40,6 +43,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -40,6 +43,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
...
@@ -47,21 +52,27 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -47,21 +52,27 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
using
GemmEpilogue
=
std
::
conditional_t
<
// CDataType,
CShuffleEpilogue
,
// M_Warp * N_Warp * K_Warp * Warp_Size,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
// 64,
CDataType
,
// TilePartitioner::kN,
kPadM
,
// kPadM,
kPadN
,
// kPadN>>;
kTilePermute
,
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
kOutputRank
,
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
true
,
2
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
GemmPipelineA
GmemBGmemCRegV1Default
Policy
;
using
CodegenGemmPolicy
=
ck_tile
::
Universal
GemmPipelineA
gBgCr
Policy
;
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
...
@@ -81,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -81,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
{
throw
std
::
runtime_error
(
"Wrong! Arguments not supported! Skipping gemm!
\n
"
);
}
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel with args:"
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
da59d3b2
...
@@ -119,9 +119,12 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -119,9 +119,12 @@ int run_gemm_example_with_layouts(int argc,
}
else
if
(
init_method
==
1
)
{
}
else
if
(
init_method
==
1
)
{
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
{
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck_tile
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_k_n
);
ck_tile
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
}
}
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
View file @
da59d3b2
...
@@ -46,17 +46,6 @@ template <typename ALayout,
...
@@ -46,17 +46,6 @@ template <typename ALayout,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
NXdlPerWave
,
// 2, 256,
// 256, 256,
// 32, 8, 8,
// 32, 32,
// 4, 4,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// 1, 1, S<1, 32, 1, 8>, 8,
// ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
da59d3b2
...
@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
...
@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
//
//
printf
(
"up_lengths_:"
);
printf
(
"up_lengths_:"
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
//
//
printf
(
"}"
);
printf
(
"}"
);
...
@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
...
@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"left_pad_length_: "
);
printf
(
"left_pad_length_: "
);
print
x
(
left_pad_length_
);
print
(
left_pad_length_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"right_pad_length_: "
);
printf
(
"right_pad_length_: "
);
print
x
(
right_pad_length_
);
print
(
right_pad_length_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -337,12 +337,12 @@ struct left_pad
...
@@ -337,12 +337,12 @@ struct left_pad
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"left_pad_length_: "
);
printf
(
"left_pad_length_: "
);
print
x
(
left_pad_length_
);
print
(
left_pad_length_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
...
@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"right_pad_length_: "
);
printf
(
"right_pad_length_: "
);
print
x
(
right_pad_length_
);
print
(
right_pad_length_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
...
@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"coefficients_: "
);
printf
(
"coefficients_: "
);
print
x
(
coefficients_
);
print
(
coefficients_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
...
@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
//
//
printf
(
"low_lengths_ "
);
printf
(
"low_lengths_ "
);
print
x
(
low_lengths_
);
print
(
low_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"up_lengths_ "
);
printf
(
"up_lengths_ "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
...
@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
//
//
printf
(
"low_lengths_ "
);
printf
(
"low_lengths_ "
);
print
x
(
low_lengths_
);
print
(
low_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"low_lengths_scan_ "
);
printf
(
"low_lengths_scan_ "
);
print
x
(
low_lengths_scan_
);
print
(
low_lengths_scan_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"up_lengths_ "
);
printf
(
"up_lengths_ "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
...
@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
//
//
printf
(
"up_lengths_"
);
printf
(
"up_lengths_"
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"up_lengths_scan_"
);
printf
(
"up_lengths_scan_"
);
print
x
(
up_lengths_scan_
);
print
(
up_lengths_scan_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
...
@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
//
//
printf
(
"low_idx_: "
);
printf
(
"low_idx_: "
);
print
x
(
low_idx_
);
print
(
low_idx_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
...
@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
printf
(
"insert{"
);
printf
(
"insert{"
);
//
//
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
...
@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
...
@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"slice_begin_: "
);
printf
(
"slice_begin_: "
);
print
x
(
slice_begin_
);
print
(
slice_begin_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"slice_end_: "
);
printf
(
"slice_end_: "
);
print
x
(
slice_end_
);
print
(
slice_end_
);
printf
(
"}"
);
printf
(
"}"
);
}
// namespace ck
}
// namespace ck
...
@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
...
@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
...
@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
printf
(
"}"
);
printf
(
"}"
);
...
@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
...
@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
//
//
printf
(
"offset_length_: "
);
printf
(
"offset_length_: "
);
print
x
(
offset_length_
);
print
(
offset_length_
);
printf
(
"}"
);
printf
(
"}"
);
}
}
...
@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
...
@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
//
//
printf
(
"up_lengths_: "
);
printf
(
"up_lengths_: "
);
print
x
(
up_lengths_
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
", "
);
printf
(
"}"
);
printf
(
"}"
);
...
...
include/ck_tile/core/container/tuple.hpp
View file @
da59d3b2
...
@@ -195,11 +195,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
...
@@ -195,11 +195,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
using
base
=
impl
::
tuple_base
<
make_index_sequence
<
sizeof
...(
T
)
>
,
T
...
>
;
using
base
=
impl
::
tuple_base
<
make_index_sequence
<
sizeof
...(
T
)
>
,
T
...
>
;
CK_TILE_HOST_DEVICE
constexpr
tuple
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
tuple
()
=
default
;
CK_TILE_HOST_DEVICE
void
print
()
const
{
// printf("tuple{size: %d, data: [", size());
// ((printf("%d ", Is)), ...);
// printf("]}");
}
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template
<
typename
U
>
template
<
typename
U
>
CK_TILE_HOST_DEVICE
constexpr
tuple
(
std
::
initializer_list
<
U
>
us
)
:
base
(
us
)
CK_TILE_HOST_DEVICE
constexpr
tuple
(
std
::
initializer_list
<
U
>
us
)
:
base
(
us
)
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
da59d3b2
...
@@ -201,17 +201,6 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
...
@@ -201,17 +201,6 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return
unpacks
;
return
unpacks
;
}
}
template
<
typename
StaticTensor
>
CK_TILE_DEVICE
void
dump_static_tensor
(
StaticTensor
&
t
){
constexpr
auto
span_2d
=
decltype
(
t
)
::
get_distributed_spans
();
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
printf
(
"%f,"
,
type_convert
<
float
>
(
t
(
i_j_idx
)));
});
printf
(
"
\n
"
);
});
}
namespace
detail
{
namespace
detail
{
// check if 2 static_distributed_tensor has same data type and size of element
// check if 2 static_distributed_tensor has same data type and size of element
...
...
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
View file @
da59d3b2
...
@@ -89,45 +89,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
...
@@ -89,45 +89,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
remove_cvref_t
<
decltype
(
top_dim_ids
)
>>
{
idx_hidden
};
remove_cvref_t
<
decltype
(
top_dim_ids
)
>>
{
idx_hidden
};
}
}
// template <typename Adaptor, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate_debug(const Adaptor& adaptor,
// const TopIndex& idx_top)
// {
// static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
// "wrong! # of dimension inconsistent");
// constexpr index_t ntransform = Adaptor::get_num_of_transform();
// constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
// constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
// constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
// multi_index<ndim_hidden> idx_hidden;
// // idx_hidden.print();
// // initialize visible index
// set_container_subset(idx_hidden, top_dim_ids, idx_top);
// // calculate hidden index
// static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
// auto itran = itran_p1 - number<1>{};
// const auto& tran = adaptor.get_transforms().at(itran);
// tran.print();
// constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
// constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
// const auto idx_up = get_container_subset(idx_hidden, dims_up);
// multi_index<dims_low.size()> idx_low;
// tran.calculate_lower_index(idx_low, idx_up);
// set_container_subset(idx_hidden, dims_low, idx_low);
// idx_hidden.print();
// });
// return tensor_adaptor_coordinate<ndim_hidden,
// remove_cvref_t<decltype(bottom_dim_ids)>,
// remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
// }
template
<
bool
JudgeDoTransforms
=
true
,
template
<
bool
JudgeDoTransforms
=
true
,
typename
Adaptor
,
typename
Adaptor
,
typename
AdaptorCoord
,
typename
AdaptorCoord
,
...
...
include/ck_tile/core/tensor/tensor_coordinate.hpp
View file @
da59d3b2
...
@@ -66,16 +66,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
...
@@ -66,16 +66,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
remove_cvref_t
<
decltype
(
TensorDesc
::
get_top_dimension_hidden_ids
())
>>
{
remove_cvref_t
<
decltype
(
TensorDesc
::
get_top_dimension_hidden_ids
())
>>
{
adaptor_coord
};
adaptor_coord
};
}
}
// template <typename TensorDesc, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate_debug(const TensorDesc& tensor_desc,
// const TopIndex& idx_top)
// {
// const auto adaptor_coord = make_tensor_adaptor_coordinate_debug(tensor_desc, idx_top);
// return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
// remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
// adaptor_coord};
// }
template
<
bool
JudgeDoTransforms
=
true
,
typename
TensorDesc
,
typename
TensorCoord
,
typename
Index
>
template
<
bool
JudgeDoTransforms
=
true
,
typename
TensorDesc
,
typename
TensorCoord
,
typename
Index
>
CK_TILE_HOST_DEVICE
constexpr
void
CK_TILE_HOST_DEVICE
constexpr
void
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
da59d3b2
...
@@ -440,13 +440,6 @@ struct tile_window_linear
...
@@ -440,13 +440,6 @@ struct tile_window_linear
// we directly use BottomTensorView transform to compute the offset, in case padding
// we directly use BottomTensorView transform to compute the offset, in case padding
auto
bottom_tensor_coord
=
auto
bottom_tensor_coord
=
make_tensor_coordinate
(
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
make_tensor_coordinate
(
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
// if(threadIdx.x == 0) {
// bottom_tensor_coord =
// make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// printf("off00 %d %d\n",i_access, bottom_tensor_coord.get_offset() );
// bottom_tensor_coord.get_hidden_index().print();
// bottom_tensor_coord.get_index().print();
// }
return
bottom_tensor_coord
.
get_offset
();
return
bottom_tensor_coord
.
get_offset
();
}
}
else
else
...
@@ -550,7 +543,8 @@ struct tile_window_linear
...
@@ -550,7 +543,8 @@ struct tile_window_linear
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
// read from bottom tensor
const
vector_t
vec_value
=
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
da59d3b2
...
@@ -53,82 +53,99 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -53,82 +53,99 @@ struct BlockGemmASmemBSmemCRegV1
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
//
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr
index_t
MPerBlockPerIter
=
MPerBlock
/
MIterPerWarp
;
//
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
//
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
//
const index_t iMWarp = get_warp_id() / NWarp;
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
//
const index_t iNWarp = get_warp_id() % NWarp;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
// if(threadIdx.x == 0 && blockIdx.x==0) {
// construct A-warp-window
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
// }
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
auto
a_warp_window_tmp
=
make_tile_window
(
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
MPerBlock
,
KPerBlock
),
make_tuple
(
number
<
WG
::
kM
>
{},
number
<
WG
::
kK
>
{}),
{
0
,
0
},
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
Policy
::
template
MakeALDSTileDistribution
<
Problem
>());
make_static_tile_distribution
(
typename
WG
::
AWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
a_warp_window_tmp
),
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
MPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
NPerBlock
,
KPerBlock
),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
{
0
,
0
},
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
Policy
::
template
MakeBLDSTileDistribution
<
Problem
>());
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
auto
a_block_tensor
=
load_tile
(
a_warp_window_tmp
);
#if 0 // FIXME: using array will cause register spill
auto
b_block_tensor
=
load_tile
(
b_warp_window_tmp
);
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
// if (threadIdx.x == 0) {
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// printf("0\n");
{
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
{
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
move_tile_window(b_warp_windows(nIter)(kIter),
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
}
// });
}
// printf("\n");
#else
// });
statically_indexed_array
<
// }
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
// __syncthreads();
NIterPerWarp
>
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
b_warp_windows
;
using
BWarpDstr
=
typename
WG
::
BWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
using
BWarpTensor
=
typename
WG
::
BWarpTensor
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A Block window
// read A warp tensor from A block window
AWarpTensor
a_warp_tensor
;
const
auto
a_warp_tensor
=
load_tile
(
a_warp_windows
(
mIter
)(
kIter
));
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B block tensor
// read B warp tensor from B Block window
BWarpTensor
b_warp_tensor
;
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
b_warp_tensor
.
get_thread_buffer
()
=
b_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
nIter
,
kIter
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
// read C warp tensor from C block tensor
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
CWarpTensor
c_warp_tensor
;
...
@@ -192,72 +209,5 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -192,72 +209,5 @@ struct BlockGemmASmemBSmemCRegV1
return
c_block_tensor
;
return
c_block_tensor
;
}
}
};
};
// construct A-warp-window
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
// {a_warp_window_tmp}};
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
// construct B-warp-window
// auto b_warp_window_tmp = make_tile_window(
// b_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
// {b_warp_window_tmp}};
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
da59d3b2
...
@@ -40,8 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -40,8 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
}
#else
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
{},
2
,
2
);
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
#endif
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
...
@@ -55,96 +54,6 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -55,96 +54,6 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
}
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALDSTileDistribution
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
static_assert
(
false
,
"Unsupported tensor_layout right now."
);
}
else
{
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr
index_t
K2
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K1
=
2
;
constexpr
index_t
K0
=
KPerBlock
/
K1
/
K2
;
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr
index_t
M2
=
32
;
// MPERXDL
constexpr
index_t
M1
=
2
;
//MWAVE
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
2
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
else
{
static_assert
(
false
,
"Unsupported shape right now."
);
}
}
}
/// @brief Build the tile distribution used to read the B tile from LDS.
///
/// Splits K as (K0, K1, K2) where K2 is the number of elements per 16-byte
/// vector access, and N as (N0, N1, N2) where N2 = 32 (NPerXdl lanes) and
/// N1 = 2 (NWave) — assumed to match the warp-gemm config; TODO confirm
/// against GetWarpGemmMWarpNWarp.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
{
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using BLayout   = remove_cvref_t<typename Problem::BLayout>;

    constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        // Use a dependent condition: a bare static_assert(false, ...) in a template
        // is ill-formed even inside a discarded `if constexpr` branch (pre-C++23,
        // CWG2518), so many compilers reject it at definition time.
        static_assert(sizeof(BDataType) == 0, "Unsupported tensor_layout right now.");
    }
    else
    {
        // K2: elements per 16-byte vector load; K1 fixed at 2; K0: remaining repeats.
        constexpr index_t K2 = 16 / sizeof(BDataType);
        constexpr index_t K1 = 2;
        static_assert(KPerBlock % (K1 * K2) == 0,
                      "KPerBlock must be divisible by K1 * K2.");
        constexpr index_t K0 = KPerBlock / K1 / K2;
        // N2: NPerXdl; N1: NWave; N0: remaining N repeats.
        // (Original comments said MPerXdl/MWave — copy-paste from the A-side helper.)
        constexpr index_t N2 = 32; // NPerXdl
        constexpr index_t N1 = 2;  // NWave
        // coalesce reading for each block
        if constexpr(get_warp_size() % (N2 * K0) == 0)
        {
            // Messages fixed: they previously referred to M2/M1 while checking N2/N1.
            static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
            static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
            static_assert(NPerBlock % (N2 * N1) == 0,
                          "NPerBlock must be divisible by N2 * N1.");
            constexpr index_t N0 = NPerBlock / (N2 * N1);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<2>,
                                           tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
                                           tuple<sequence<0, 1>, sequence<2, 1>>,
                                           tuple<sequence<0, 1>, sequence<1, 2>>,
                                           sequence<1, 2, 2>,
                                           sequence<0, 0, 2>>{});
        }
        else
        {
            // Dependent condition for the same pre-C++23 discarded-branch reason.
            static_assert(sizeof(BDataType) == 0, "Unsupported shape right now.");
        }
    }
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment