Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cca67d13
Commit
cca67d13
authored
Jan 23, 2025
by
ThomasNing
Browse files
Finished the coding of the feature, Compiler not in the way we supposed to have
parent
3e0047a6
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
640 additions
and
18 deletions
+640
-18
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+3
-0
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+16
-3
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+7
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+1
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+196
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
+413
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+1
-1
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
cca67d13
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
target_compile_options
(
tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
cca67d13
...
@@ -29,9 +29,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -29,9 +29,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE || \
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
@@ -44,6 +42,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -44,6 +42,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#endif
#endif
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
...
include/ck_tile/ops/gemm.hpp
View file @
cca67d13
...
@@ -36,6 +36,7 @@
...
@@ -36,6 +36,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
cca67d13
...
@@ -436,7 +436,8 @@ struct GemmKernel
...
@@ -436,7 +436,8 @@ struct GemmKernel
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
[
&
]()
{
const
auto
&
c_block_tile
=
[
&
]()
{
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
{
{
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
cca67d13
...
@@ -35,6 +35,13 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -35,6 +35,13 @@ struct GemmPipelineAgBgCrImplBase
store_tile
(
lds_tile_window
,
block_tile_tmp
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
}
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
DstBlockTile
&
dst_block_tile
,
const
SrcTileWindow
&
lds_tile_window
)
const
{
load_tile
(
dst_block_tile
,
lds_tile_window
);
}
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
{
{
// A tile in LDS
// A tile in LDS
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
cca67d13
...
@@ -77,8 +77,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -77,8 +77,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -339,7 +337,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -339,7 +337,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
{
block_gemm
(
c_block_tile
,
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// latency
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
cca67d13
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_
v1_default
_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_
compute_v4
_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemC
R
eg
V1
DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemC
r
eg
ComputeV4
DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
{
{
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
...
@@ -45,6 +45,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -45,6 +45,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -60,6 +64,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -60,6 +64,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
template
<
>
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
{
{
using
Base
=
PipelineImplBase
;
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
{
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I0
{});
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I0
{});
...
@@ -119,7 +125,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -119,7 +125,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
typename
ADramBlockWindowTmp
,
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
AElementFunction
,
typename
BElementFunction
>
typename
BElementFunction
>
...
@@ -128,8 +136,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -128,8 +136,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
index_t
num_loop
,
void
*
__restrict__
p_smem_0
,
void
*
p_smem_0
,
void
*
__restrict__
p_smem_1
)
void
*
p_smem_1
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
...
@@ -188,13 +196,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -188,13 +196,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto
b_copy_lds_window0
=
auto
b_copy_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
k
NPerBlock
>
{},
number
<
k
KPerBlock
>
{}),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
{
0
,
0
},
BBlockTileDistr
);
BBlockTileDistr
);
auto
b_copy_lds_window1
=
auto
b_copy_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
k
NPerBlock
>
{},
number
<
k
KPerBlock
>
{}),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
{
0
,
0
},
BBlockTileDistr
);
BBlockTileDistr
);
...
@@ -213,10 +221,188 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -213,10 +221,188 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_sync_lds
();
block_sync_lds
();
block_gemm
.
LocalPrefetch
();
constexpr
auto
ALdsTileDistr
=
decltype
(
make_static_tile_distribution
(
BlockGemm
::
MakeABlockDistributionEncode
())){};
constexpr
auto
BLdsTileDistr
=
decltype
(
make_static_tile_distribution
(
BlockGemm
::
MakeBBlockDistributionEncode
())){};
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
ALdsTile
a_block_tile1
;
BLdsTile
b_block_tile0
;
BLdsTile
b_block_tile1
;
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_block0
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_block1
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
b_lds_ld_window0
=
make_tile_window_linear
(
b_lds_block0
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BLdsTileDistr
);
auto
b_lds_ld_window1
=
make_tile_window_linear
(
b_lds_block1
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BLdsTileDistr
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
if
(
HasHotLoop
)
{
// minus 2 because we have ping-pong double buffer.
index_t
iCounter
=
__builtin_amdgcn_readfirstlane
(
num_loop
-
2
);
do
{
// ping
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// gemm
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// pong
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// gemm
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
}
iCounter
-=
2
;
}
while
(
iCounter
>
1
);
}
// tail 3
if
(
TailNum
==
TailNumber
::
Three
)
{
// 3
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
// 1
{
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
if
(
TailNum
==
TailNumber
::
Two
)
{
// 2
{
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
8
,
0
);
// MFMA
});
__builtin_amdgcn_sched_barrier
(
0
);
}
// 1
{
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
// when tail num is one
{
{
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
return
c_block_tile
;
}
}
};
};
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem_0
,
void
*
p_smem_1
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
a_element_func
,
b_dram_block_window_tmp
,
b_element_func
,
num_loop
,
p_smem_0
,
p_smem_1
);
}
public:
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
index_t
num_loop
,
void
*
p_smem_0
,
void
*
p_smem_1
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem_0
,
p_smem_1
);
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
0 → 100644
View file @
cca67d13
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
cca67d13
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment