Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cca67d13
Commit
cca67d13
authored
Jan 23, 2025
by
ThomasNing
Browse files
Finished the coding of the feature, Compiler not in the way we supposed to have
parent
3e0047a6
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
640 additions
and
18 deletions
+640
-18
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+3
-0
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+16
-3
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+7
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+1
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+196
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
+413
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+1
-1
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
cca67d13
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
target_compile_options
(
tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
cca67d13
...
...
@@ -29,9 +29,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE || \
CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
...
@@ -44,6 +42,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#endif
constexpr
bool
kPadM
=
false
;
...
...
include/ck_tile/ops/gemm.hpp
View file @
cca67d13
...
...
@@ -36,6 +36,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
cca67d13
...
...
@@ -436,7 +436,8 @@ struct GemmKernel
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
[
&
]()
{
const
auto
&
c_block_tile
=
[
&
]()
{
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
{
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
cca67d13
...
...
@@ -35,6 +35,13 @@ struct GemmPipelineAgBgCrImplBase
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
DstBlockTile
&
dst_block_tile
,
const
SrcTileWindow
&
lds_tile_window
)
const
{
load_tile
(
dst_block_tile
,
lds_tile_window
);
}
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
{
// A tile in LDS
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
cca67d13
...
...
@@ -77,8 +77,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
@@ -339,7 +337,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_gemm
(
c_block_tile
,
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
cca67d13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_
v1_default
_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_
compute_v4
_policy.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemC
R
eg
V1
DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemC
r
eg
ComputeV4
DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
...
...
@@ -45,6 +45,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
@@ -60,6 +64,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
{
using
Base
=
PipelineImplBase
;
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I0
{});
...
...
@@ -119,7 +125,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
typename
ADramBlockWindowTmp
,
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
...
...
@@ -128,8 +136,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
void
*
p_smem_0
,
void
*
p_smem_1
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -188,13 +196,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto
b_copy_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
k
NPerBlock
>
{},
number
<
k
KPerBlock
>
{}),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BBlockTileDistr
);
auto
b_copy_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
k
NPerBlock
>
{},
number
<
k
KPerBlock
>
{}),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BBlockTileDistr
);
...
...
@@ -213,10 +221,188 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_sync_lds
();
block_gemm
.
LocalPrefetch
();
constexpr
auto
ALdsTileDistr
=
decltype
(
make_static_tile_distribution
(
BlockGemm
::
MakeABlockDistributionEncode
())){};
constexpr
auto
BLdsTileDistr
=
decltype
(
make_static_tile_distribution
(
BlockGemm
::
MakeBBlockDistributionEncode
())){};
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
ALdsTile
a_block_tile1
;
BLdsTile
b_block_tile0
;
BLdsTile
b_block_tile1
;
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_block0
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_block1
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
b_lds_ld_window0
=
make_tile_window_linear
(
b_lds_block0
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BLdsTileDistr
);
auto
b_lds_ld_window1
=
make_tile_window_linear
(
b_lds_block1
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BLdsTileDistr
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
if
(
HasHotLoop
)
{
// minus 2 because we have ping-pong double buffer.
index_t
iCounter
=
__builtin_amdgcn_readfirstlane
(
num_loop
-
2
);
do
{
// ping
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// gemm
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// pong
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// gemm
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
iCounter
-=
2
;
}
while
(
iCounter
>
1
);
}
// tail 3
if
(
TailNum
==
TailNumber
::
Three
)
{
// 3
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
// 1
{
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
if
(
TailNum
==
TailNumber
::
Two
)
{
// 2
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
8
,
0
);
// MFMA
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// 1
{
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
// when tail num is one
{
{
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
return
c_block_tile
;
}
};
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem_0
,
void
*
p_smem_1
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
a_element_func
,
b_dram_block_window_tmp
,
b_element_func
,
num_loop
,
p_smem_0
,
p_smem_1
);
}
public:
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
index_t
num_loop
,
void
*
p_smem_0
,
void
*
p_smem_1
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem_0
,
p_smem_1
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
0 → 100644
View file @
cca67d13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
namespace
ck_tile
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc
;
}
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kNPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
constexpr
index_t
smem_size
=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
))
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
kMPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
AccDataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
cca67d13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment