Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e2a318bc
Commit
e2a318bc
authored
Nov 12, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/moe_quant
parents
d0405504
2b6458dd
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
840 additions
and
254 deletions
+840
-254
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+40
-23
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+3
-0
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+10
-9
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
+5
-5
include/ck_tile/core/tensor/shuffle_tile.hpp
include/ck_tile/core/tensor/shuffle_tile.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+2
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+50
-20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+44
-19
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+267
-63
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+116
-38
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+264
-52
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+10
-6
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+16
-7
script/process_perf_data.py
script/process_perf_data.py
+2
-2
script/process_qa_data.sh
script/process_qa_data.sh
+1
-0
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+6
-6
No files found.
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
e2a318bc
...
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
...
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"x row_stride, if -1 then equal to n"
)
.
insert
(
"xr_stride"
,
"-1"
,
"x residule row_stride, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"y row_stride, if -1 then equal to n"
)
.
insert
(
"yr_stride"
,
"-1"
,
"y residule row_stride, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
...
@@ -54,11 +57,20 @@ template <typename InDataType,
...
@@ -54,11 +57,20 @@ template <typename InDataType,
bool
SaveMeanVar
>
bool
SaveMeanVar
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
stride
<
0
)
if
(
x_stride
<
0
)
stride
=
n
;
x_stride
=
n
;
ck_tile
::
index_t
xr_stride
=
arg_parser
.
get_int
(
"xr_stride"
);
if
(
xr_stride
<
0
)
xr_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
ck_tile
::
index_t
yr_stride
=
arg_parser
.
get_int
(
"yr_stride"
);
if
(
yr_stride
<
0
)
yr_stride
=
n
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
...
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return
false
;
return
false
;
}
}
assert
(
stride
>=
n
);
assert
(
x_
stride
>=
n
);
using
TypeConfig
=
LayerNormTypeConfig
<
InDataType
,
OutDataType
,
XScaleDataType
,
YScaleDataType
>
;
using
TypeConfig
=
LayerNormTypeConfig
<
InDataType
,
OutDataType
,
XScaleDataType
,
YScaleDataType
>
;
...
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
xr_
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
yr_
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
...
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
}();
std
::
cout
<<
"["
<<
prec_str
<<
"]"
std
::
cout
<<
"["
<<
prec_str
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", xr_stride:"
<<
xr_stride
<<
", y_stride:"
<<
y_stride
<<
", yr_stride:"
<<
yr_stride
<<
std
::
flush
;
layernorm2d_fwd_traits
traits
{
layernorm2d_fwd_traits
traits
{
prec_i
,
prec_o
,
prec_sx
,
prec_sy
,
SaveMeanVar
,
fused_add
,
fused_quant
};
prec_i
,
prec_o
,
prec_sx
,
prec_sy
,
SaveMeanVar
,
fused_add
,
fused_quant
};
...
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
stride
};
x_stride
,
// x row_stride
xr_stride
,
// x residule row stride
y_stride
,
// y row stride
yr_stride
};
// y residule row stride
float
ave_time
=
layernorm2d_fwd
(
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf
.
FromDevice
(
y_host_dev
.
data
());
y_buf
.
FromDevice
(
y_host_dev
.
data
());
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host_dev
({
m
,
n
},
{
yr_
stride
,
1
});
if
(
fused_add
==
1
)
if
(
fused_add
==
1
)
{
{
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
...
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
[
rtol
,
atol
]
=
get_elimit
<
InDataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
InDataType
>
();
if
(
stride
==
n
)
if
(
x_
stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
...
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
{
std
::
vector
<
YDataType
>
y_host_dev_row
(
y_host_dev
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
YDataType
>
y_host_dev_row
(
y_host_dev
.
begin
()
+
i_r
*
y_
stride
,
y_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
y_host_dev
.
begin
()
+
i_r
*
y_
stride
+
n
);
std
::
vector
<
YDataType
>
y_host_ref_row
(
y_host_ref
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
YDataType
>
y_host_ref_row
(
y_host_ref
.
begin
()
+
i_r
*
y_
stride
,
y_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
y_host_ref
.
begin
()
+
i_r
*
y_
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_host_dev_row
,
pass
&=
ck_tile
::
check_err
(
y_host_dev_row
,
y_host_ref_row
,
y_host_ref_row
,
std
::
string
(
"OUT["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"OUT["
)
+
std
::
to_string
(
i_r
)
+
...
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
fused_add
==
1
)
if
(
fused_add
==
1
)
{
{
std
::
vector
<
YResidualDataType
>
y_residual_host_dev_row
(
std
::
vector
<
YResidualDataType
>
y_residual_host_dev_row
(
y_residual_host_dev
.
begin
()
+
i_r
*
stride
,
y_residual_host_dev
.
begin
()
+
i_r
*
yr_
stride
,
y_residual_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
y_residual_host_dev
.
begin
()
+
i_r
*
yr_
stride
+
n
);
std
::
vector
<
YResidualDataType
>
y_residual_host_ref_row
(
std
::
vector
<
YResidualDataType
>
y_residual_host_ref_row
(
x_host
.
begin
()
+
i_r
*
stride
,
x_host
.
begin
()
+
i_r
*
stride
+
n
);
x_host
.
begin
()
+
i_r
*
yr_
stride
,
x_host
.
begin
()
+
i_r
*
yr_
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev_row
,
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev_row
,
y_residual_host_ref_row
,
y_residual_host_ref_row
,
std
::
string
(
"ADD["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"ADD["
)
+
std
::
to_string
(
i_r
)
+
...
...
example/ck_tile/03_gemm/README.md
View file @
e2a318bc
...
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
...
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
mkdir build && cd build
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j
```
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
e2a318bc
...
@@ -17,10 +17,11 @@
...
@@ -17,10 +17,11 @@
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadC
=
true
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
...
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
CShuffleEpilogue
,
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CDataType
,
kPad
A
,
kPad
M
,
kPad
B
,
kPad
N
,
kTilePermute
,
kTilePermute
,
kOutputRank
,
kOutputRank
,
1
,
1
,
...
@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner
::
kM
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPad
A
,
kPad
B
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPad
M
,
kPad
N
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
...
...
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
View file @
e2a318bc
...
@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPad
A
=
true
;
constexpr
bool
kPad
M
=
true
;
constexpr
bool
kPad
B
=
true
;
constexpr
bool
kPad
N
=
true
;
constexpr
bool
kPad
C
=
true
;
constexpr
bool
kPad
K
=
true
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPad
C
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPad
N
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
...
...
include/ck_tile/core/tensor/shuffle_tile.hpp
View file @
e2a318bc
...
@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
...
@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
}
}
else
else
{
{
// NOT implemented
static_assert
(
false
,
"The shuffle should always happen!"
);
}
}
}
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
e2a318bc
...
@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
!=
0
);
static_assert
(
N0
!=
0
);
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
e2a318bc
...
@@ -115,12 +115,22 @@ struct GemmKernel
...
@@ -115,12 +115,22 @@ struct GemmKernel
}
}
}();
}();
auto
a_pad_view
=
pad_tensor_view
(
auto
a_pad_view
=
[
&
]()
{
a_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// somehow clang-format is splitting below line into multiple.
return
pad_tensor_view
(
// clang-format off
a_tensor_view
,
sequence
<
false
,
GemmPipeline
::
kPadA
>
{});
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
// clang-format on
auto
a_block_window
=
make_tile_window
(
auto
a_block_window
=
make_tile_window
(
...
@@ -128,12 +138,22 @@ struct GemmKernel
...
@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
{
i_m
,
0
});
auto
b_pad_view
=
pad_tensor_view
(
auto
b_pad_view
=
[
&
]()
{
b_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadB
>
{});
b_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
auto
b_block_window
=
make_tile_window
(
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
b_pad_view
,
...
@@ -171,18 +191,28 @@ struct GemmKernel
...
@@ -171,18 +191,28 @@ struct GemmKernel
}
}
}();
}();
auto
c_pad_view
=
pad_tensor_view
(
auto
c_pad_view
=
[
&
]()
{
c_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadC
>
{});
c_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
auto
c_block_window
=
make_tile_window
(
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
EpiloguePipeline
{}(
c_b
lock
_w
indow
,
c_block_tile
);
EpiloguePipeline
{}(
CB
lock
W
indow
_pad
,
c_block_tile
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
e2a318bc
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
// Where is the right place for HasHotLoop and TailNum ???
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
e2a318bc
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
{
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
// A LDS tile window for store
auto
a_copy_lds_window
=
auto
a_copy_lds_window
=
make_tile_window
(
make_tile_window
(
a_lds_block
,
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
// B DRAM tile window for load
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
// B LDS tile window for store
auto
b_copy_lds_window
=
auto
b_copy_lds_window
=
make_tile_window
(
make_tile_window
(
b_lds_block
,
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
auto
a_lds_gemm_window
=
make_tile_window
(
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegBlockDescriptor
<
Problem
>());
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
}
else
{
store_tile
(
a_copy_lds_window
,
tile_elementwise_in
(
a_element_func
,
a_block_tile
));
}
// LDS write 0
// LDS write 0
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
else
{
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_block_tile
));
}
}
}
index_t
iCounter
=
num_loop
-
1
;
index_t
iCounter
=
num_loop
-
1
;
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write i + 1
// LDS write i + 1
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
}
else
{
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
iCounter
--
;
iCounter
--
;
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
e2a318bc
...
@@ -11,6 +11,7 @@ namespace ck_tile {
...
@@ -11,6 +11,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
#if 0
#if 0
// 2d
// 2d
template <typename Problem>
template <typename Problem>
...
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
smem_size
;
return
smem_size
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
#elif 1
#elif 1
// fake XOR
// fake XOR
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
{
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
return
make_static_tile_distribution
(
static_assert
(
KPack
%
K3
==
0
);
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
K2
=
KPack
/
K3
;
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
))
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
sequence
<
1
,
2
>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
0
,
1
>>
{});
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
#else // coalesce reading for each warps
return
make_static_tile_distribution
(
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
return
make_static_tile_distribution
(
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tile_distribution_encoding
<
sequence
<
1
>
,
sequence
<
2
,
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
sequence
<
3
,
1
>>
{});
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
}
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
else
sequence
<
1
,
2
>
,
{
sequence
<
1
,
1
>>
{});
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
#endif
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
#if 1 // coalesce reading for each blocks
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
K3
=
total_pixels
/
N1
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
return
make_static_tile_distribution
(
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
tile_distribution_encoding
<
sequence
<
1
>
,
{
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
0
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
#else // coalesce reading for each warps
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
1
,
3
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
}
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
else
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
sequence
<
1
,
2
>
,
constexpr
index_t
K2_m
=
K2
/
K1
;
sequence
<
1
,
1
>>
{});
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
#endif
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
kMPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
e2a318bc
...
@@ -3,40 +3,133 @@
...
@@ -3,40 +3,133 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
static
constexpr
int
_VectorSize
=
16
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
struct
GemmPipelineProblem
Base
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
VectorLoadSize
=
GemmTraits
::
_VectorSize
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
bool
kPadM
=
GemmTraits
::
kPadM
;
static
constexpr
bool
kPadN
=
GemmTraits
::
kPadN
;
static
constexpr
bool
kPadK
=
GemmTraits
::
kPadK
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kM
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
ADataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
ADataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
ADataType
);
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentB
()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kN
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
BDataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
BDataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
BDataType
);
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentC
()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N2
=
std
::
min
(
BlockGemmShape
::
kN
/
N1
,
get_warp_size
());
constexpr
index_t
M0
=
get_warp_size
()
/
N2
;
constexpr
index_t
M1
=
BlockGemmShape
::
kM
/
M0
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
1
:
_VectorSize
/
sizeof
(
ADataType
);
return
std
::
min
(
M1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
static
constexpr
index_t
VectorSizeB
=
kPadB
?
1
:
_VectorSize
/
sizeof
(
BDataType
);
}
static
constexpr
index_t
VectorSizeC
=
kPadC
?
1
:
_VectorSize
/
sizeof
(
CDataType
);
else
{
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
std
::
min
(
BlockGemmShape
::
kM
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
BlockGemmShape
::
kN
/
N0
;
return
std
::
min
(
N1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
}
}
static
constexpr
index_t
VectorSizeA
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadK
?
1
:
GetAlignmentA
();
}
else
{
return
kPadM
?
1
:
GetAlignmentA
();
}
}();
static
constexpr
index_t
VectorSizeB
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentB
();
}
else
{
return
kPadK
?
1
:
GetAlignmentB
();
}
}();
static
constexpr
index_t
VectorSizeC
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentC
();
}
else
{
return
kPadM
?
1
:
GetAlignmentC
();
}
}();
};
};
// Alias for GemmPipelineProblem
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
using
GemmPipelineProblem
=
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
...
@@ -45,30 +138,15 @@ template <typename ADataType_,
...
@@ -45,30 +138,15 @@ template <typename ADataType_,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
struct
UniversalGemmPipelineProblem
:
public
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
static
constexpr
auto
TailNum
=
TailNum_
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
_VectorSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
VectorSizeB
=
kPadB
?
_VectorSize
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
VectorSizeC
=
kPadC
?
_VectorSize
/
sizeof
(
CDataType
)
:
1
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
e2a318bc
...
@@ -9,12 +9,8 @@
...
@@ -9,12 +9,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// UniversalGemm Policy
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
struct
UniversalGemmPipelineAgBgCrPolicy
{
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy
TransposeC
>
;
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
Layout
A
>::
value
)
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
A
Layout
>::
value
)
{
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
?
1
...
@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
Layout
B
>::
value
)
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
B
Layout
>::
value
)
{
{
// NLdsLayer * K0 as logical Bank
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
...
@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy
return
smem_size
;
return
smem_size
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
typename
Problem
::
BDataType
,
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
K0
=
KPerBlock
/
K1
;
{
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
return
make_static_tile_distribution
(
constexpr
index_t
K2
=
KPack
/
K3
;
tile_distribution_encoding
<
sequence
<
1
>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
1
,
2
>
,
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
sequence
<
0
,
1
>>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
typename
Problem
::
BDataType
,
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
constexpr
index_t
K0
=
KPerBlock
/
K1
;
{
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
N1
==
0
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
return
make_static_tile_distribution
(
constexpr
index_t
K2
=
KPack
/
K3
;
tile_distribution_encoding
<
sequence
<
1
>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
1
,
2
>
,
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
sequence
<
0
,
1
>>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
e2a318bc
...
@@ -3,19 +3,23 @@
...
@@ -3,19 +3,23 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
bool
kPad
A
_
,
template
<
bool
kPad
M
_
,
bool
kPad
B
_
,
bool
kPad
N
_
,
bool
kPad
C
_
,
bool
kPad
K
_
,
typename
ALayout_
,
typename
ALayout_
,
typename
BLayout_
,
typename
BLayout_
,
typename
CLayout_
>
typename
CLayout_
>
struct
TileGemmTraits
struct
TileGemmTraits
{
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
e2a318bc
...
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
...
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
...
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
...
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
using
Hargs
=
Layernorm2dFwdHostArgs
;
using
Hargs
=
Layernorm2dFwdHostArgs
;
...
@@ -112,7 +118,10 @@ struct Layernorm2dFwd
...
@@ -112,7 +118,10 @@ struct Layernorm2dFwd
hargs
.
epsilon
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
x_stride
,
hargs
.
xr_stride
,
hargs
.
y_stride
,
hargs
.
yr_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
...
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
...
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
xr_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
...
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
...
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
yr_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
...
script/process_perf_data.py
View file @
e2a318bc
...
@@ -133,12 +133,12 @@ def parse_logfile(logfile):
...
@@ -133,12 +133,12 @@ def parse_logfile(logfile):
if
'Best Perf'
in
line
:
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
lst
=
line
.
split
()
res
.
append
(
lst
[
4
])
res
.
append
(
lst
[
4
])
elif
'onnx_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
elif
'onnx_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
lst
=
line
.
split
()
res
.
append
(
lst
[
33
])
res
.
append
(
lst
[
33
])
elif
'splitK_gemm'
in
logfile
:
elif
'splitK_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
lst
=
line
.
split
()
...
...
script/process_qa_data.sh
View file @
e2a318bc
...
@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
...
@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
python3 process_perf_data.py perf_reduction.log
python3 process_perf_data.py perf_reduction.log
python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log
file
=
./perf_fmha_fwd_gfx942.log
file
=
./perf_fmha_fwd_gfx942.log
if
[
-e
"
$file
"
]
;
then
if
[
-e
"
$file
"
]
;
then
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
e2a318bc
...
@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
kPad
A
=
true
;
constexpr
bool
kPad
M
=
true
;
constexpr
bool
kPad
B
=
true
;
constexpr
bool
kPad
N
=
true
;
constexpr
bool
kPad
C
=
true
;
constexpr
bool
kPad
K
=
true
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPad
C
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPad
N
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
...
@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Lunching kernel with args:"
std
::
cout
<<
"L
a
unching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
"}"
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment