Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b2c7d774
Commit
b2c7d774
authored
Jan 31, 2025
by
ThomasNing
Browse files
Add the changes from include/ck_tile
parent
d1e71770
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
480 additions
and
359 deletions
+480
-359
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+0
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+1
-1
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+17
-15
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+4
-4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+0
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+201
-60
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
+249
-271
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+3
-1
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+3
-2
No files found.
cmake/EnableCompilerWarnings.cmake
View file @
b2c7d774
...
@@ -66,7 +66,6 @@ else()
...
@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marke
-Werror
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
...
...
include/ck_tile/host.hpp
View file @
b2c7d774
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
...
@@ -34,4 +35,3 @@
...
@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
include/ck_tile/ops/batched_transpose.hpp
View file @
b2c7d774
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/gemm.hpp
View file @
b2c7d774
...
@@ -32,11 +32,11 @@
...
@@ -32,11 +32,11 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
b2c7d774
...
@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1
...
@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
W
G
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
W
arpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WarpGemm
::
kK
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
{
{
...
@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1
...
@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
W
G
::
AWarpDstrEncoding
{});
a_block_outer_dstr_encoding
,
typename
W
arpGemm
::
AWarpDstrEncoding
{});
return
a_block_dstr_encode
;
return
a_block_dstr_encode
;
}
}
...
@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1
...
@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
W
G
::
BWarpDstrEncoding
{});
b_block_outer_dstr_encoding
,
typename
W
arpGemm
::
BWarpDstrEncoding
{});
return
b_block_dstr_encode
;
return
b_block_dstr_encode
;
}
}
...
@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1
...
@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
W
G
::
CWarpDstrEncoding
{});
c_block_outer_dstr_encoding
,
typename
W
arpGemm
::
CWarpDstrEncoding
{});
return
c_block_dstr_encode
;
return
c_block_dstr_encode
;
}
}
...
@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1
...
@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1
.
get_static_tile_distribution_encoding
())
>>
,
.
get_static_tile_distribution_encoding
())
>>
,
"C distribution is wrong!"
);
"C distribution is wrong!"
);
using
AWarpDstr
=
typename
W
G
::
AWarpDstr
;
using
AWarpDstr
=
typename
W
arpGemm
::
AWarpDstr
;
using
BWarpDstr
=
typename
W
G
::
BWarpDstr
;
using
BWarpDstr
=
typename
W
arpGemm
::
BWarpDstr
;
using
CWarpDstr
=
typename
W
G
::
CWarpDstr
;
using
CWarpDstr
=
typename
W
arpGemm
::
CWarpDstr
;
using
AWarpTensor
=
typename
W
G
::
AWarpTensor
;
using
AWarpTensor
=
typename
W
arpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
W
G
::
BWarpTensor
;
using
BWarpTensor
=
typename
W
arpGemm
::
BWarpTensor
;
using
CWarpTensor
=
typename
W
G
::
CWarpTensor
;
using
CWarpTensor
=
typename
W
arpGemm
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
...
@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1
...
@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
// warp GEMM
W
G
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
W
arpGemm
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
c_block_tensor
.
set_y_sliced_thread_data
(
...
@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1
...
@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
W
G
::
CWarpDstrEncoding
{});
c_block_outer_dstr_encoding
,
typename
W
arpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
return
c_block_tensor
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
b2c7d774
...
@@ -45,17 +45,17 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -45,17 +45,17 @@ struct GemmPipelineAgBgCrImplBase
{
{
load_tile
(
dst_block_tile
,
lds_tile_window
);
load_tile
(
dst_block_tile
,
lds_tile_window
);
}
}
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
{
{
// A tile in LDS
// A tile in LDS
ADataType
*
__restrict__
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
ADataType
*
__restrict__
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_least_multiple
(
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
// B tile in LDS
// B tile in LDS
BDataType
*
__restrict__
p_b_lds
=
static_cast
<
BDataType
*>
(
BDataType
*
__restrict__
p_b_lds
=
static_cast
<
BDataType
*>
(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
b2c7d774
...
@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
isDoubleSmemBuffer
=
Problem
::
isDoubleSmemBuffer
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
b2c7d774
...
@@ -9,8 +9,30 @@
...
@@ -9,8 +9,30 @@
namespace
ck_tile
{
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
>
struct
BaseGemmPipelineAgBgCrCompV4
{
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
CK_TILE_HOST
static
constexpr
TailNumber
GetBlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Two
;
}
};
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV
3
<
Problem
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV
4
<
Problem
>
{
{
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
...
@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
P
roblem
::
VectorSizeA
;
static
constexpr
index_t
VectorSizeA
=
P
olicy
::
template
GetVectorSizeA
<
Problem
>()
;
static
constexpr
index_t
VectorSizeB
=
P
roblem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
P
olicy
::
template
GetVectorSizeB
<
Problem
>()
;
static
constexpr
index_t
VectorSizeC
=
P
roblem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
P
olicy
::
template
GetVectorSizeC
<
Problem
>()
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
template
IsTransposeC
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
struct
PipelineImpl
:
public
PipelineImplBase
...
@@ -115,12 +140,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -115,12 +140,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ignore
=
i
;
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst
/
num_issue
,
0
);
// DS read : 2
0x100
,
num_ds_read_inst
/
num_issue
,
0
);
// DS read : 2
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA: 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA: 1
__builtin_amdgcn_sched_group_barrier
(
__builtin_amdgcn_sched_group_barrier
(
0x200
,
num_ds_write_inst
/
num_issue
,
0
);
// DS write : 1
0x200
,
num_ds_write_inst
/
num_issue
,
0
);
// DS write : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA : 1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read :1
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read :1
__builtin_amdgcn_sched_group_barrier
(
__builtin_amdgcn_sched_group_barrier
(
0x008
,
C_MFMA_Inst_Num
/
num_issue
-
3
,
0
);
// MFMA : 5
0x008
,
C_MFMA_Inst_Num
/
num_issue
-
3
,
0
);
// MFMA : 5
});
});
...
@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
"wrong!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
constexpr
bool
is_a_col_major
=
NPerBlock
==
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
////////////// global window & register /////////////////
////////////// global window & register /////////////////
// A DRAM tile window for load
// A DRAM tile window for load
...
@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile
a_global_load_tile
;
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
BBlockTile
b_global_load_tile
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// global prefetch 0
// global prefetch 0
// global read 0
// global read 0
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
////////////// LDS desc, window & register /////////////////
////////////// LDS desc, window & register /////////////////
auto
&&
[
a_lds_block0
,
b_lds_block0
]
=
Base
::
GetABLdsTensorViews
(
p_smem_0
);
auto
&&
[
a_lds_block0
,
b_lds_block0
]
=
Base
::
GetABLdsTensorViews
(
p_smem_0
);
auto
&&
[
a_lds_block1
,
b_lds_block1
]
=
Base
::
GetABLdsTensorViews
(
p_smem_1
);
auto
&&
[
a_lds_block1
,
b_lds_block1
]
=
Base
::
GetABLdsTensorViews
(
p_smem_1
);
auto
a_copy_lds_window0
=
auto
a_copy_lds_window0
=
make_tile_window
(
make_tile_window
(
a_lds_block0
,
a_lds_block0
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
auto
a_copy_lds_window1
=
make_tile_window
(
ABlockTileDistr
);
a_lds_block1
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
a_copy_lds_window1
=
auto
b_copy_lds_window0
=
make_tile_window
(
make_tile_window
(
a_lds_block1
,
b_lds_block0
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
auto
b_copy_lds_window1
=
make_tile_window
(
ABlockTileDistr
);
b_lds_block1
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
b_copy_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BBlockTileDistr
);
auto
b_copy_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
BBlockTileDistr
);
// Block GEMM
// Block GEMM
auto
block_gemm
=
BlockGemm
();
auto
block_gemm
=
BlockGemm
();
...
@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
if
constexpr
(
is_a_col_major
)
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
}
// global read 1
// global read 1
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
block_sync_lds
();
block_sync_lds
();
...
@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
if
constexpr
(
is_a_col_major
)
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
if
(
HasHotLoop
)
if
(
HasHotLoop
)
{
{
...
@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
if
constexpr
(
is_a_col_major
)
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
// gemm
// gemm
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
HotLoopScheduler
();
...
@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
a_block_tile0
,
a_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefetch
(
b_block_tile0
,
b_lds_ld_window0
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
if
constexpr
(
is_a_col_major
)
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window1
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window1
,
b_global_load_tile
,
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
// gemm
// gemm
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
HotLoopScheduler
();
...
@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds
();
block_sync_lds
();
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
a_block_tile1
,
a_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefetch
(
b_block_tile1
,
b_lds_ld_window1
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
if
constexpr
(
is_a_col_major
)
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
}
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
}
// 2
// 2
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
View file @
b2c7d774
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
b2c7d774
...
@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase
...
@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
Gemm
Traits
::
isDoubleSmemBuffer
;
static
constexpr
bool
isDoubleSmemBuffer
=
Traits
::
isDoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
...
@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem
...
@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
Traits
::
isDoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
auto
TailNum
=
TailNum_
;
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
b2c7d774
...
@@ -22,8 +22,6 @@ struct TileGemmTraits
...
@@ -22,8 +22,6 @@ struct TileGemmTraits
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
// TODO this can't be hardcoded here! Should be in policy!
// TODO this can't be hardcoded here! Should be in policy!
static
constexpr
int
_VectorSize
=
16
;
static
constexpr
int
_VectorSize
=
16
;
...
@@ -37,6 +35,7 @@ struct TileGemmTraits
...
@@ -37,6 +35,7 @@ struct TileGemmTraits
template
<
bool
kPadM_
,
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
kPadN_
,
bool
kPadK_
,
bool
kPadK_
,
bool
isDoubleSmemBuffer_
,
typename
ALayout_
,
typename
ALayout_
,
typename
BLayout_
,
typename
BLayout_
,
typename
CLayout_
,
typename
CLayout_
,
...
@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits
...
@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
using
CLayout
=
CLayout_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment