Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
34612efd
Commit
34612efd
authored
Feb 07, 2025
by
ThomasNing
Browse files
address the new comments
parent
7409674a
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
133 additions
and
135 deletions
+133
-135
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+45
-20
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+16
-12
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp
...ipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp
+2
-46
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+35
-35
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+1
-1
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+30
-17
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
34612efd
...
...
@@ -16,7 +16,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V
3
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V
4
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
...
...
include/ck_tile/ops/gemm.hpp
View file @
34612efd
...
...
@@ -30,9 +30,9 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
34612efd
...
...
@@ -14,26 +14,51 @@ namespace ck_tile {
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBRegCRegV1DefaultPolicy
>
struct
BlockGemmARegBRegCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WarpGemm
::
kK
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
private:
/**
 * @brief Compile-time traits derived from a pipeline problem and a GEMM policy.
 *
 * Gathers the element types, the block-tile extents and the warp-level tiling
 * in one place; the enclosing block GEMM instantiates it as its `Traits` alias.
 *
 * @tparam PipelineProblem_ Problem descriptor. Must provide ADataType, BDataType,
 *         CDataType, BlockGemmShape, kBlockSize and Scheduler.
 * @tparam GemmPolicy_ Policy. Must provide a static member template
 *         GetWarpGemmMWarpNWarp<Problem>() whose result supports at<0>() .. at<2>().
 */
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
    // Template arguments with cv/reference qualifiers stripped.
    using Problem = remove_cvref_t<PipelineProblem_>;
    using Policy  = remove_cvref_t<GemmPolicy_>;

    // Element types and block shape, forwarded from the problem.
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr auto Scheduler     = Problem::Scheduler;

    // Block-tile extents along M, N and K.
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // The policy answers with an indexable bundle: element 0 supplies the warp
    // GEMM type, elements 1 and 2 supply the warp counts along M and N.
    static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;

    static constexpr index_t MWarp = config.template at<1>();
    static constexpr index_t NWarp = config.template at<2>();

    // Per-warp iteration counts (integer division of the block tile by the
    // warp tiling).
    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
    static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

    // Taken directly from the warp GEMM's kKPerThread.
    static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using
Traits
=
GemmTraits_
<
Problem_
,
Policy_
>
;
using
WarpGemm
=
typename
Traits
::
WarpGemm
;
using
ADataType
=
remove_cvref_t
<
typename
Traits
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Traits
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Traits
::
CDataType
>
;
static
constexpr
index_t
KIterPerWarp
=
Traits
::
KIterPerWarp
;
static
constexpr
index_t
MIterPerWarp
=
Traits
::
MIterPerWarp
;
static
constexpr
index_t
NIterPerWarp
=
Traits
::
NIterPerWarp
;
static
constexpr
index_t
MWarp
=
Traits
::
MWarp
;
static
constexpr
index_t
NWarp
=
Traits
::
NWarp
;
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
{
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
34612efd
...
...
@@ -520,8 +520,8 @@ struct GemmKernel
CK_TILE_DEVICE
static
void
RunGemm2LDS
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
void
*
smem_ptr_0
,
void
*
smem_ptr_1
,
void
*
__restrict__
smem_ptr_0
,
void
*
__restrict__
smem_ptr_1
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
34612efd
...
...
@@ -4,8 +4,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
namespace
ck_tile
{
...
...
@@ -37,18 +36,23 @@ struct BaseGemmPipelineAgBgCrCompV4
}
};
// Compute optimized pipeline version 4
// The difference between this pipeline and compute version 3 is that it has two LDS windows that
// use a ping-pong buffer to fetch data from global memory. While one LDS window is fetching data
// from global memory, the other feeds the warps running the MFMA matrix multiplication. When
// the matrices are large, this keeps the warps busy and hides the memory loading
// time. It performs better than Compute Version 3 when both use the same
// block tile, and it still performs better when M, N, K are all > 8K even when the compute V3
// block size is 2 times that of compute V4.
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
>
/**
* @brief Compute optimized pipeline version 4
*
* This version introduces a dual LDS window mechanism using a ping-pong buffer approach
* for more efficient data handling from global memory. Unlike compute version 3, this method
* allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
* matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
* thereby significantly reducing memory load times and enhancing overall performance.
*
* @note This version shows improved performance over Compute Version 3 with the same block tile.
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAgBgCrCompV4DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV4
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrCompV
3
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrCompV
4
<
Problem
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag
mem_bgmem
_cr
eg
_comp
ute
_v4_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag
_bg
_cr_comp_v4
_default
_policy.hpp
View file @
34612efd
...
...
@@ -14,24 +14,9 @@ namespace ck_tile {
// UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on
// member functions instead.
struct
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
:
public
UniversalGemmBasePolicy
struct
GemmPipelineAgBgCrCompV4DefaultPolicy
:
public
UniversalGemmBasePolicy
<
GemmPipelineAgBgCrCompV4DefaultPolicy
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
constexpr
index_t
KPack
=
BlockGemm
::
KPack
;
return
KPack
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
constexpr
index_t
KPack
=
BlockGemm
::
KPack
;
return
KPack
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
...
...
@@ -82,35 +67,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBa
return
b_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
return
smem_size_a
+
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
34612efd
...
...
@@ -9,6 +9,7 @@
namespace
ck_tile
{
template
<
typename
Derived
>
struct
UniversalGemmBasePolicy
{
static
constexpr
auto
I0
=
number
<
0
>
{};
...
...
@@ -270,15 +271,11 @@ struct UniversalGemmBasePolicy
BTileAccessPattern
>
;
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
};
// UniversalGemm Policy
struct
UniversalGemmPipelineAgBgCrPolicy
:
public
UniversalGemmBasePolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
}
...
...
@@ -286,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
}
/**
 * @brief Byte size reserved for the A tile, computed from the derived policy's
 *        A LDS block descriptor.
 *
 * The descriptor's element-space size is scaled by sizeof(Problem::ADataType)
 * and passed through integer_least_multiple(..., 16).
 * NOTE(review): presumably this rounds the byte count up to a multiple of 16;
 * confirm against integer_least_multiple.
 *
 * @tparam Problem Problem descriptor providing ADataType.
 * @return The resulting size as index_t, evaluated at compile time.
 */
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
    using AElement = typename Problem::ADataType;

    // Resolved through Derived so the concrete policy's descriptor is used.
    constexpr auto lds_desc_a = Derived::template MakeALdsBlockDescriptor<Problem>();

    constexpr index_t padded_bytes_a =
        integer_least_multiple(sizeof(AElement) * lds_desc_a.get_element_space_size(), 16);

    return padded_bytes_a;
}
/**
 * @brief Byte size reserved for the B tile, computed from the derived policy's
 *        B LDS block descriptor.
 *
 * The descriptor's element-space size is scaled by sizeof(Problem::BDataType)
 * and passed through integer_least_multiple(..., 16).
 * NOTE(review): presumably this rounds the byte count up to a multiple of 16;
 * confirm against integer_least_multiple.
 *
 * @tparam Problem Problem descriptor providing BDataType.
 * @return The resulting size as index_t, evaluated at compile time.
 */
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
    using BElement = typename Problem::BDataType;

    // Resolved through Derived so the concrete policy's descriptor is used.
    constexpr auto lds_desc_b = Derived::template MakeBLdsBlockDescriptor<Problem>();

    constexpr index_t padded_bytes_b =
        integer_least_multiple(sizeof(BElement) * lds_desc_b.get_element_space_size(), 16);

    return padded_bytes_b;
}
/**
 * @brief Combined size for the A and B tiles.
 *
 * @tparam Problem Problem descriptor forwarded to GetSmemSizeA / GetSmemSizeB.
 * @return GetSmemSizeA<Problem>() + GetSmemSizeB<Problem>(), evaluated at
 *         compile time.
 */
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr index_t bytes_for_a = GetSmemSizeA<Problem>();
    constexpr index_t bytes_for_b = GetSmemSizeB<Problem>();

    return bytes_for_a + bytes_for_b;
}
};
// UniversalGemm Policy
struct
UniversalGemmPipelineAgBgCrPolicy
:
public
UniversalGemmBasePolicy
<
UniversalGemmPipelineAgBgCrPolicy
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
...
...
@@ -531,35 +560,6 @@ struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
#endif
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
return
smem_size_a
+
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
34612efd
...
...
@@ -17,7 +17,7 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
//
using CompV4
= ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using
CompV4
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV4
>
;
using
CompV3
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV3
>
;
// TODO: Enable the Memory pipeline once it has been updated for vector loads on non-K-major tensors.
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
34612efd
...
...
@@ -18,6 +18,30 @@ enum struct GemmPipelineType
CompV4
};
/**
 * @brief Maps a GemmPipelineType enumerator to the pipeline classes used by the test.
 *
 * The primary template is declared but not defined; the mapping is supplied by
 * explicit partial specializations, each exposing `base_pipeline` and `pipeline`.
 *
 * @tparam PT      Pipeline kind to select.
 * @tparam Problem Problem type forwarded to the selected pipeline templates.
 */
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
/// Selector for GemmPipelineType::Mem: picks the AgBgCrMem pipeline pair,
/// instantiated with Problem only (any further template arguments keep their defaults).
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
/// Selector for GemmPipelineType::CompV3: picks the AgBgCrCompV3 pipeline pair,
/// instantiated with Problem only (any further template arguments keep their defaults).
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
/// Selector for GemmPipelineType::CompV4: picks the AgBgCrCompV4 pipeline pair,
/// instantiated with Problem only (any further template arguments keep their defaults).
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};
template
<
typename
Tuple
>
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
...
...
@@ -37,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void
invoke_gemm
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
...
...
@@ -84,12 +108,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
CompV3
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV4
<
GemmPipelineProblem
>>>
;
using
BaseGemmPipeline
=
typename
GemmPipelineTypeSelector
<
PipelineType
,
GemmPipelineProblem
>::
base_pipeline
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
...
@@ -110,15 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
tail_number_v
>
;
using
GemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
CompV3
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
ck_tile
::
GemmPipelineAgBgCrCompV4
<
UniversalGemmProblem
>>>
;
using
GemmPipeline
=
typename
GemmPipelineTypeSelector
<
PipelineType
,
UniversalGemmProblem
>::
pipeline
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment