Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d51f4e52
Commit
d51f4e52
authored
Nov 22, 2024
by
dummycoderfe
Browse files
use 32x32x8 ok, fix scratch store
parent
bc4366d4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
29 additions
and
22 deletions
+29
-22
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+1
-1
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+14
-13
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+9
-1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+2
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-5
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
d51f4e52
...
...
@@ -72,7 +72,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
Universal
GemmPipelineA
gBgCr
Policy
;
using
CodegenGemmPolicy
=
ck_tile
::
GemmPipelineA
GmemBGmemCRegV1Default
Policy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
d51f4e52
...
...
@@ -73,7 +73,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"k"
,
"2048"
,
"k dimension"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
R
"
,
"B tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
C
"
,
"B tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
d51f4e52
...
...
@@ -194,22 +194,23 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("
a_layout
");
std::string b_layout = arg_parser.get_str("
b_layout
");
if(a_layout == "
R
" && b_layout == "
R
")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "
R
" && b_layout == "
C
")
// if(a_layout == "
R
" && b_layout == "
R
")
// {
// return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
// }
// else
if(a_layout == "
R
" && b_layout == "
C
")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "
C
" && b_layout == "
C
")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "
C
" && b_layout == "
R
")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
//
else if(a_layout == "
C
" && b_layout == "
C
")
//
{
//
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
//
}
//
else if(a_layout == "
C
" && b_layout == "
R
")
//
{
//
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
//
}
else
{
throw std::runtime_error("
Unsupported
data
layout
configuration
for
A
,
B
and
C
tensors
!
");
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
d51f4e52
...
...
@@ -42,6 +42,9 @@ struct BlockGemmASmemBSmemCRegV1
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
// }
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
...
...
@@ -60,6 +63,12 @@ struct BlockGemmASmemBSmemCRegV1
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
// }
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
// construct A-warp-window
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
...
...
@@ -136,7 +145,6 @@ struct BlockGemmASmemBSmemCRegV1
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
d51f4e52
...
...
@@ -40,7 +40,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K8
{},
2
,
2
);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
d51f4e52
...
...
@@ -112,10 +112,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
return
smem_size_a
+
smem_size_b
;
}
template
<
typename
Problem
>
...
...
@@ -259,7 +256,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
{
//Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment