Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c6eb7c30
Commit
c6eb7c30
authored
Jan 17, 2025
by
ThomasNing
Browse files
comp v4 setup
parent
5d671a5f
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
260 additions
and
5 deletions
+260
-5
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+9
-1
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+5
-2
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+14
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+222
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+3
-0
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
c6eb7c30
...
...
@@ -13,19 +13,27 @@
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V2 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
_V2
#endif
// Pick the pipeline type, its base pipeline, the scheduler and the
// isDoubleSmemBuffer flag according to CK_TILE_PIPELINE_DEFAULT.
// Any value other than the three handled below is rejected at preprocessing time.
#if CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY

// Memory pipeline: Interwave scheduler, isDoubleSmemBuffer off.
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
constexpr bool isDoubleSmemBuffer{false};

#elif CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE

// Compute pipeline V3: Intrawave scheduler, isDoubleSmemBuffer off.
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr bool isDoubleSmemBuffer{false};

#elif CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2

// Compute pipeline V4: reuses the CompV3 base pipeline and the Intrawave
// scheduler, but turns isDoubleSmemBuffer on.
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr bool isDoubleSmemBuffer{true};

#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
c6eb7c30
...
...
@@ -29,7 +29,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE || \
CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
...
@@ -61,7 +63,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
isDoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
...
...
include/ck_tile/ops/gemm.hpp
View file @
c6eb7c30
...
...
@@ -29,6 +29,7 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
c6eb7c30
...
...
@@ -435,8 +435,20 @@ struct GemmKernel
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
const
auto
&
c_block_tile
=
[
&
]()
{
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
{
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
return
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
,
smem_ptr_1
);
}
else
{
return
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
}
}();
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
c6eb7c30
...
...
@@ -70,6 +70,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
0 → 100644
View file @
c6eb7c30
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
/// @brief Compute pipeline V4 for GEMM (A and B read through DRAM tile windows,
///        accumulation into a C block tile).
///
/// The pipeline derives from BaseGemmPipelineAgBgCrCompV3 and takes two scratch
/// pointers (`p_smem_0`, `p_smem_1`) from which two independent sets of A/B LDS
/// views and windows are built.
///
/// @tparam Problem GEMM pipeline problem description (data types, layouts, block
///                 shape, padding flags, vector sizes, isDoubleSmemBuffer).
/// @tparam Policy  Policy providing the block GEMM, tile distributions and the
///                 scratch size; defaults to GemmPipelineAGmemBGmemCRegV1DefaultPolicy.
///
/// NOTE(review): this is a setup-only state of the pipeline. The Intrawave
/// operator() performs the prologue and stops; see the notes on that method.
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV3<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    // Per-block tile extents taken from the block GEMM shape.
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Read by the kernel to decide whether a second scratch buffer is passed in.
    static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;

    /// @brief Scratch size for one buffer, as reported by Policy::GetSmemSize<Problem>().
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    /// @brief Forwards Policy::IsTransposeC().
    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }

    /// @brief Scheduler-specific implementation; the primary template adds nothing
    ///        on top of PipelineImplBase.
    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    /// @brief Implementation used with the Intrawave scheduler.
    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        /// @brief Emits scheduling-group barriers that interleave MFMA, DS read,
        ///        DS write and VMEM read instructions, one group per buffer load.
        ///
        /// Masks used below, per the inline comments: 0x008 = MFMA, 0x100 = DS read,
        /// 0x200 = DS write, 0x020 = VMEM read. The counts quoted in those comments
        /// are examples; the real counts come from the constants computed here.
        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;

            constexpr index_t A_Buffer_Load_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
            constexpr index_t B_Buffer_Load_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * VectorSizeB);

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t A_LDS_Read_Inst_Num =
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            // NOTE(review): the B-side count is derived from MPerBlock, while the other
            // B-side counts use NPerBlock. Left unchanged; confirm whether NPerBlock
            // was intended (the two agree whenever MPerBlock == NPerBlock).
            constexpr index_t B_LDS_Read_Inst_Num =
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                                (BlockSize / WaveSize) /
                                                (MPerXDL * NPerXDL * KPerXDL);

            // The read count is halved unless one read covers exactly 16 bytes.
            constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
                                                    ? A_LDS_Read_Inst_Num
                                                    : A_LDS_Read_Inst_Num / 2;
            constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
                                                    ? B_LDS_Read_Inst_Num
                                                    : B_LDS_Read_Inst_Num / 2;

            constexpr auto num_ds_read_inst     = num_ds_read_inst_a + num_ds_read_inst_b;
            constexpr auto num_ds_write_inst    = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
            constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;

            static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read : 2
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);      // MFMA: 1
                __builtin_amdgcn_sched_group_barrier(
                    0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write : 1
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);       // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);       // VMEM read :1
                __builtin_amdgcn_sched_group_barrier(
                    0x008, C_MFMA_Inst_Num / num_buffer_load_inst - 3, 0); // MFMA : 5
            });
            __builtin_amdgcn_sched_barrier(0);
        }

        /// @brief Pipeline prologue: first global reads, first LDS write, second
        ///        global read, then a block-wide LDS sync.
        ///
        /// @param a_dram_block_window_tmp A window; must be MPerBlock x KPerBlock with
        ///                                element type ADataType (checked statically).
        /// @param a_element_func          Element-wise function applied to A in LocalPrefill.
        /// @param b_dram_block_window_tmp B window; its first length must be NPerBlock and
        ///                                its element type BDataType (checked statically).
        /// @param b_element_func          Element-wise function applied to B in LocalPrefill.
        /// @param num_loop                Loop count; not used by this setup-only body.
        /// @param p_smem_0                Scratch for the first set of A/B LDS views.
        /// @param p_smem_1                Scratch for the second set of A/B LDS views.
        ///
        /// NOTE(review): the body ends after the prologue. There is no hot loop or
        /// tail, and no return statement, so the deduced return type is void and
        /// `c_block_tile` is never handed back to the caller.
        /// NOTE(review): `Base` here is the enclosing struct's alias for
        /// BaseGemmPipelineAgBgCrCompV3<Problem>. Confirm that GlobalPrefetch,
        /// LocalPrefill and GetABLdsTensorViews are reachable through it rather
        /// than through PipelineImplBase.
        template <typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* __restrict__ p_smem_0,
                                       void* __restrict__ p_smem_1)
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "wrong!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "wrong!");

            ////////////// global window & register /////////////////
            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                        make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                        a_dram_block_window_tmp.get_window_origin(),
                                        Policy::template MakeADramTileDistribution<Problem>());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                        make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                        b_dram_block_window_tmp.get_window_origin(),
                                        Policy::template MakeBDramTileDistribution<Problem>());

            // A/B register tiles for global load
            constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
            constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();

            using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
            using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));

            ABlockTile a_global_load_tile;
            BBlockTile b_global_load_tile;

            // global prefetch 0
            // global read 0
            Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);

            ////////////// LDS desc, window & register /////////////////
            // One pair of A/B views per scratch pointer.
            auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
            auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);

            auto a_copy_lds_window0 =
                make_tile_window(a_lds_block0,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 ABlockTileDistr);

            auto a_copy_lds_window1 =
                make_tile_window(a_lds_block1,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 ABlockTileDistr);

            // The B windows use this struct's NPerBlock / KPerBlock constants; the
            // k-prefixed spellings are not declared anywhere in this struct.
            auto b_copy_lds_window0 =
                make_tile_window(b_lds_block0,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 BBlockTileDistr);

            auto b_copy_lds_window1 =
                make_tile_window(b_lds_block1,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 BBlockTileDistr);

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
            Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);

            // global read 1
            Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);

            block_sync_lds();
            block_gemm.LocalPrefetch();
        }
    };
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
c6eb7c30
...
...
@@ -121,6 +121,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
c6eb7c30
...
...
@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase
static
constexpr
bool
kPadN
=
GemmTraits
::
kPadN
;
static
constexpr
bool
kPadK
=
GemmTraits
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
GemmTraits
::
isDoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
c6eb7c30
...
...
@@ -10,6 +10,7 @@ namespace ck_tile {
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
kPadK_
,
bool
isDoubleSmemBuffer_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
>
...
...
@@ -19,6 +20,8 @@ struct TileGemmTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment