Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c87aa6c8
"vscode:/vscode.git/clone" did not exist on "33b3c0f85ffb647a1fc831c59c112bcfca5c06b8"
Unverified
Commit
c87aa6c8
authored
Nov 26, 2024
by
Illia Silin
Committed by
GitHub
Nov 26, 2024
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
60afb522
b70f367f
Changes
267
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2355 additions
and
718 deletions
+2355
-718
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
+46
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+48
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+661
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+50
-20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+9
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...le/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+44
-19
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+302
-66
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+118
-38
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+357
-312
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+10
-6
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+75
-55
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+180
-45
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+353
-104
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+29
-29
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+57
-11
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+1
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+11
-3
include/ck_tile/ops/smoothquant.hpp
include/ck_tile/ops/smoothquant.hpp
+1
-0
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
0 → 100644
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// TODO: alow 2 gemm have different type
template
<
typename
ADataType_
,
typename
GDataType_
,
typename
DDataType_
,
typename
AccDataType_
,
typename
ODataType_
,
typename
AScaleDataType_
,
typename
GScaleDataType_
,
typename
DScaleDataType_
,
typename
YSmoothScaleDataType_
,
typename
TopkWeightDataType_
,
typename
IndexDataType_
,
// data type for all indexing
typename
GateActivation_
,
// = ck_tile::element_wise::Silu,
typename
BlockShape_
,
// shoule be FusedMoeGemmShape
typename
Traits_
>
struct
FusedMoeGemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
GDataType
=
remove_cvref_t
<
GDataType_
>
;
using
DDataType
=
remove_cvref_t
<
DDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
AScaleDataType
=
remove_cvref_t
<
AScaleDataType_
>
;
using
GScaleDataType
=
remove_cvref_t
<
GScaleDataType_
>
;
using
DScaleDataType
=
remove_cvref_t
<
DScaleDataType_
>
;
using
YSmoothScaleDataType
=
remove_cvref_t
<
YSmoothScaleDataType_
>
;
using
TopkWeightDataType
=
remove_cvref_t
<
TopkWeightDataType_
>
;
using
IndexDataType
=
remove_cvref_t
<
IndexDataType_
>
;
// the input for next gemm should have same time as
using
YDataType
=
ADataType
;
using
GateActivation
=
remove_cvref_t
<
GateActivation_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
enum
class
FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute
=
0
,
b_nr_kr_kw_nw_kv
=
1
,
// 0,1,3,4,2,5
b_nr_kr_waveflatten
=
b_nr_kr_kw_nw_kv
,
};
template
<
bool
IsGateOnly_
,
bool
UseSmoothQuant_
,
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
,
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
{
// Gate+Up or Gate only
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
FusedMoeGemmWeightPermuteEnum
PermuteEnum
=
PermuteEnum_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
// Note: this need to be a bit mask
enum
class
FusedMoeGemmPipelineSequencerEnum
{
SLD_A
=
1
<<
0
,
// shared load a
SLD_B
=
1
<<
1
,
GLD_A
=
1
<<
2
,
// global load a
GLD_B
=
1
<<
3
,
SST_A
=
1
<<
4
,
// shared store a
SST_B
=
1
<<
5
,
GST_O
=
1
<<
6
,
// global store out
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
c87aa6c8
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
0 → 100644
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
struct
BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
template
<
typename
PipelineProblem_
,
typename
GemmPolicy_
>
struct
GemmTraits_
{
using
Problem
=
remove_cvref_t
<
PipelineProblem_
>
;
using
Policy
=
remove_cvref_t
<
GemmPolicy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static_assert
(
MWarp
==
BlockGemmShape
::
BlockWarps
::
at
(
number
<
0
>
{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"
);
static_assert
(
NWarp
==
BlockGemmShape
::
BlockWarps
::
at
(
number
<
1
>
{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"
);
static_assert
(
WarpGemm
::
kM
==
BlockGemmShape
::
WarpTile
::
at
(
number
<
0
>
{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!"
);
static_assert
(
WarpGemm
::
kN
==
BlockGemmShape
::
WarpTile
::
at
(
number
<
1
>
{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!"
);
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WarpGemm
::
kK
;
static_assert
(
MIterPerWarp
*
MWarp
*
WarpGemm
::
kM
==
MPerBlock
,
"Error! Warps should cover all Block tile!"
);
static_assert
(
NIterPerWarp
*
NWarp
*
WarpGemm
::
kN
==
NPerBlock
,
"Error! Warps should cover all Block tile!"
);
static
constexpr
index_t
MPerBlockPerIter
=
MWarp
*
WarpGemm
::
kM
;
static
constexpr
index_t
NPerBlockPerIter
=
NWarp
*
WarpGemm
::
kN
;
static
constexpr
index_t
KPerBlockPerIter
=
WarpGemm
::
kK
;
using
AWarpTileDistr
=
remove_cvref_t
<
decltype
(
make_static_tile_distribution
(
typename
WarpGemm
::
AWarpDstrEncoding
{}))
>
;
using
BWarpTileDistr
=
remove_cvref_t
<
decltype
(
make_static_tile_distribution
(
typename
WarpGemm
::
BWarpDstrEncoding
{}))
>
;
using
AWarpTile
=
remove_cvref_t
<
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
AWarpTileDistr
{}))
>
;
using
BWarpTile
=
remove_cvref_t
<
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BWarpTileDistr
{}))
>
;
// TODO: Should we have two policies? Interwave & Intrawave ??
static
constexpr
index_t
InterWaveSchedulingMacClusters
=
1
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
static
constexpr
index_t
KPerThread
=
KPerBlock
/
WarpGemm
::
kK
*
KPack
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
};
public:
using
Traits
=
GemmTraits_
<
Problem_
,
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Traits
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Traits
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Traits
::
CDataType
>
;
using
WarpGemm
=
remove_cvref_t
<
typename
Traits
::
WarpGemm
>
;
static
constexpr
index_t
KIterPerWarp
=
Traits
::
KIterPerWarp
;
static
constexpr
index_t
MIterPerWarp
=
Traits
::
MIterPerWarp
;
static
constexpr
index_t
NIterPerWarp
=
Traits
::
NIterPerWarp
;
static
constexpr
index_t
MWarp
=
Traits
::
MWarp
;
static
constexpr
index_t
NWarp
=
Traits
::
NWarp
;
static
constexpr
auto
Scheduler
=
Traits
::
Scheduler
;
private:
template
<
GemmPipelineScheduler
Scheduler
,
typename
GemmTraits
>
struct
BlockGemmImpl
{
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Default
,
GemmTraits
>
{
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
typename
GemmTraits
::
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
1
>
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
GemmTraits
::
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
GemmTraits
::
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kM
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
GemmTraits
::
WarpGemm
::
kM
,
0
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kN
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
GemmTraits
::
WarpGemm
::
kN
,
0
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
// TODO: I don't have to move 0,0 window!
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
using
CWarpDstr
=
typename
GemmTraits
::
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
GemmTraits
::
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
const
auto
a_warp_tile
=
load_tile
(
a_warp_windows
(
mIter
)(
kIter
));
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
const
auto
b_warp_tile
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
typename
GemmTraits
::
WarpGemm
{}(
c_warp_tensor
,
a_warp_tile
,
b_warp_tile
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Intrawave
,
GemmTraits
>
{
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
AWarpTile
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
MIterPerWarp
>
a_warp_tiles_
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
BWarpTile
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
NIterPerWarp
>
b_warp_tiles_
;
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
1
>
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
typename
GemmTraits
::
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
GemmTraits
::
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
GemmTraits
::
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kM
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
GemmTraits
::
WarpGemm
::
kM
,
0
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kN
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
GemmTraits
::
WarpGemm
::
kN
,
0
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
GemmTraits
::
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
// TODO: I don't have to move 0,0 window!
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
load_tile
(
a_warp_tiles_
(
mIter
)(
kIter
),
a_warp_windows
(
mIter
)(
kIter
));
});
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
load_tile
(
b_warp_tiles_
(
nIter
)(
kIter
),
b_warp_windows
(
nIter
)(
kIter
));
});
});
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
[[
maybe_unused
]]
const
ASmemBlockWindow
&
a_block_window
,
[[
maybe_unused
]]
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
using
CWarpDstr
=
typename
GemmTraits
::
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
GemmTraits
::
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
typename
GemmTraits
::
WarpGemm
{}(
c_warp_tensor
,
a_warp_tiles_
[
mIter
][
kIter
],
b_warp_tiles_
[
nIter
][
kIter
]);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Interwave
,
GemmTraits
>
{
static
constexpr
index_t
KPerThread
=
GemmTraits
::
KPerThread
;
static
constexpr
index_t
NumMacClusters
=
GemmTraits
::
InterWaveSchedulingMacClusters
;
static
constexpr
index_t
KPerInnerLoop
=
ck_tile
::
max
(
KPerThread
/
NumMacClusters
,
GemmTraits
::
KPack
);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
KInnerLoopIter
=
KPerInnerLoop
/
GemmTraits
::
KPack
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
AWarpTile
,
KInnerLoopIter
>
,
GemmTraits
::
MIterPerWarp
>
a_warp_tiles_
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
BWarpTile
,
KInnerLoopIter
>
,
GemmTraits
::
NIterPerWarp
>
b_warp_tiles_
;
template
<
index_t
KIdx
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
number
<
1
>
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
typename
GemmTraits
::
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
GemmTraits
::
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
GemmTraits
::
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kM
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
GemmTraits
::
WarpGemm
::
kM
,
KIdx
*
KPerInnerLoop
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
KInnerLoopIter
>
,
GemmTraits
::
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
GemmTraits
::
WarpGemm
::
kN
>
{},
number
<
GemmTraits
::
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
GemmTraits
::
WarpGemm
::
kN
,
KIdx
*
KPerInnerLoop
},
make_static_tile_distribution
(
typename
GemmTraits
::
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
KInnerLoopIter
>
,
GemmTraits
::
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
load_tile
(
a_warp_tiles_
(
mIter
)(
kIter
),
a_warp_windows
(
mIter
)(
kIter
));
});
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
load_tile
(
b_warp_tiles_
(
nIter
)(
kIter
),
b_warp_windows
(
nIter
)(
kIter
));
});
});
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
typename
GemmTraits
::
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
using
CWarpDstr
=
typename
GemmTraits
::
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
GemmTraits
::
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
kIter
)
{
LocalPrefetch
<
kIter
.
value
>
(
a_block_window
,
b_block_window
);
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if
constexpr
(
kIter
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kInnerIter
)
{
static_for
<
0
,
GemmTraits
::
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
GemmTraits
::
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if
constexpr
(
kIter
.
value
==
KRepeat
-
1
&&
kInnerIter
.
value
==
KInnerLoopIter
-
1
&&
mIter
.
value
==
GemmTraits
::
MIterPerWarp
-
1
&&
nIter
.
value
==
GemmTraits
::
NIterPerWarp
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// warp GEMM
typename
GemmTraits
::
WarpGemm
{}(
c_warp_tensor
,
a_warp_tiles_
[
mIter
][
kInnerIter
],
b_warp_tiles_
[
nIter
][
kInnerIter
]);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
if
constexpr
(
kInnerIter
.
value
==
0
&&
mIter
.
value
==
0
&&
nIter
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
};
public:
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
.
template
LocalPrefetch
(
a_block_window
,
b_block_window
);
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
.
template
operator
()(
c_block_tensor
,
a_block_window
,
b_block_window
);
}
// C = A * B
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
auto
c_block_tensor
=
MakeCBlockTile
();
block_gemm_impl_
.
template
operator
()(
c_block_tensor
,
a_block_window
,
b_block_window
);
return
c_block_tensor
;
}
private:
BlockGemmImpl
<
Scheduler
,
Traits
>
block_gemm_impl_
{};
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
c87aa6c8
...
@@ -115,12 +115,22 @@ struct GemmKernel
...
@@ -115,12 +115,22 @@ struct GemmKernel
}
}
}();
}();
auto
a_pad_view
=
pad_tensor_view
(
auto
a_pad_view
=
[
&
]()
{
a_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// somehow clang-format is splitting below line into multiple.
return
pad_tensor_view
(
// clang-format off
a_tensor_view
,
sequence
<
false
,
GemmPipeline
::
kPadA
>
{});
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
// clang-format on
auto
a_block_window
=
make_tile_window
(
auto
a_block_window
=
make_tile_window
(
...
@@ -128,12 +138,22 @@ struct GemmKernel
...
@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
{
i_m
,
0
});
auto
b_pad_view
=
pad_tensor_view
(
auto
b_pad_view
=
[
&
]()
{
b_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadB
>
{});
b_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
auto
b_block_window
=
make_tile_window
(
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
b_pad_view
,
...
@@ -171,18 +191,28 @@ struct GemmKernel
...
@@ -171,18 +191,28 @@ struct GemmKernel
}
}
}();
}();
auto
c_pad_view
=
pad_tensor_view
(
auto
c_pad_view
=
[
&
]()
{
c_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadC
>
{});
c_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
auto
c_block_window
=
make_tile_window
(
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
EpiloguePipeline
{}(
c_b
lock
_w
indow
,
c_block_tile
);
EpiloguePipeline
{}(
CB
lock
W
indow
_pad
,
c_block_tile
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
c87aa6c8
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
// Where is the right place for HasHotLoop and TailNum ???
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
...
@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
// Block GEMM
constexpr
auto
block_gemm
=
BlockGemm
();
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
...
@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
block_sync_lds
();
//
block_gemm.LocalPrefetch();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
block_sync_lds
();
...
@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
block_sync_lds
();
//
block_gemm.LocalPrefetch();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
block_sync_lds
();
...
@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
});
});
block_sync_lds
();
block_sync_lds
();
//
block_gemm.LocalPrefetch();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
};
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
{
block_sync_lds
();
block_sync_lds
();
//
block_gemm.LocalPrefetch();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
View file @
c87aa6c8
...
@@ -11,6 +11,7 @@ namespace ck_tile {
...
@@ -11,6 +11,7 @@ namespace ck_tile {
enum
struct
GemmPipelineScheduler
enum
struct
GemmPipelineScheduler
{
{
Default
,
Intrawave
,
Intrawave
,
Interwave
,
Interwave
,
};
};
...
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
...
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{
{
switch
(
s
)
switch
(
s
)
{
{
case
ck_tile
::
GemmPipelineScheduler
::
Default
:
os
<<
"Default"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Intrawave
:
os
<<
"Intrawave"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Intrawave
:
os
<<
"Intrawave"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Interwave
:
os
<<
"Interwave"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Interwave
:
os
<<
"Interwave"
;
break
;
default:
os
<<
""
;
default:
os
<<
""
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
c87aa6c8
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
{
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
// A LDS tile window for store
auto
a_copy_lds_window
=
auto
a_copy_lds_window
=
make_tile_window
(
make_tile_window
(
a_lds_block
,
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
// B DRAM tile window for load
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
// B LDS tile window for store
auto
b_copy_lds_window
=
auto
b_copy_lds_window
=
make_tile_window
(
make_tile_window
(
b_lds_block
,
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
auto
a_lds_gemm_window
=
make_tile_window
(
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegBlockDescriptor
<
Problem
>());
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
}
else
{
store_tile
(
a_copy_lds_window
,
tile_elementwise_in
(
a_element_func
,
a_block_tile
));
}
// LDS write 0
// LDS write 0
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
else
{
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_block_tile
));
}
}
}
index_t
iCounter
=
num_loop
-
1
;
index_t
iCounter
=
num_loop
-
1
;
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write i + 1
// LDS write i + 1
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
}
else
{
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
iCounter
--
;
iCounter
--
;
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
c87aa6c8
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -11,6 +12,7 @@ namespace ck_tile {
...
@@ -11,6 +12,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
#if 0
#if 0
// 2d
// 2d
template <typename Problem>
template <typename Problem>
...
@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
...
@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
smem_size
;
return
smem_size
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
#elif 1
#elif 1
// fake XOR
// fake XOR
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
{
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
return
make_static_tile_distribution
(
static_assert
(
KPack
%
K3
==
0
);
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
K2
=
KPack
/
K3
;
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
))
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
sequence
<
1
,
2
>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
0
,
1
>>
{});
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
#else // coalesce reading for each warps
return
make_static_tile_distribution
(
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
return
make_static_tile_distribution
(
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tile_distribution_encoding
<
sequence
<
1
>
,
sequence
<
2
,
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
sequence
<
3
,
1
>>
{});
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
}
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
else
sequence
<
1
,
2
>
,
{
sequence
<
1
,
1
>>
{});
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
#endif
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
#if 1 // coalesce reading for each blocks
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
K3
=
total_pixels
/
N1
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
return
make_static_tile_distribution
(
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
tile_distribution_encoding
<
sequence
<
1
>
,
{
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
0
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
#else // coalesce reading for each warps
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
1
,
3
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
}
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
else
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
sequence
<
1
,
2
>
,
constexpr
index_t
K2_m
=
K2
/
K1
;
sequence
<
1
,
1
>>
{});
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
#endif
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
{
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
kMPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
constexpr
bool
TransposeC
=
false
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
I2
=
number
<
2
>
{};
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
AccDataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
WarpGemm
>
;
return
BlockUniversalGemmAsBsCr
<
Problem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
c87aa6c8
...
@@ -3,40 +3,135 @@
...
@@ -3,40 +3,135 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
static
constexpr
int
_VectorSize
=
16
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
struct
GemmPipelineProblem
Base
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
VectorLoadSize
=
GemmTraits
::
_VectorSize
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
bool
kPadM
=
GemmTraits
::
kPadM
;
static
constexpr
bool
kPadN
=
GemmTraits
::
kPadN
;
static
constexpr
bool
kPadK
=
GemmTraits
::
kPadK
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kM
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
ADataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
ADataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
ADataType
);
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentB
()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kN
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
BDataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
BDataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
BDataType
);
}
}
static
constexpr
index_t
VectorSizeA
=
kPadA
?
1
:
_VectorSize
/
sizeof
(
ADataType
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentC
()
static
constexpr
index_t
VectorSizeB
=
kPadB
?
1
:
_VectorSize
/
sizeof
(
BDataType
);
{
static
constexpr
index_t
VectorSizeC
=
kPadC
?
1
:
_VectorSize
/
sizeof
(
CDataType
);
if
constexpr
(
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N2
=
std
::
min
(
BlockGemmShape
::
kN
/
N1
,
get_warp_size
());
constexpr
index_t
M0
=
get_warp_size
()
/
N2
;
constexpr
index_t
M1
=
BlockGemmShape
::
kM
/
M0
;
return
std
::
min
(
M1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
}
else
{
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
std
::
min
(
BlockGemmShape
::
kM
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
BlockGemmShape
::
kN
/
N0
;
return
std
::
min
(
N1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
}
}
static
constexpr
index_t
VectorSizeA
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadK
?
1
:
GetAlignmentA
();
}
else
{
return
kPadM
?
1
:
GetAlignmentA
();
}
}();
static
constexpr
index_t
VectorSizeB
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentB
();
}
else
{
return
kPadK
?
1
:
GetAlignmentB
();
}
}();
static
constexpr
index_t
VectorSizeC
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentC
();
}
else
{
return
kPadM
?
1
:
GetAlignmentC
();
}
}();
};
};
// Alias for GemmPipelineProblem
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
using
GemmPipelineProblem
=
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
...
@@ -45,30 +140,15 @@ template <typename ADataType_,
...
@@ -45,30 +140,15 @@ template <typename ADataType_,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
struct
UniversalGemmPipelineProblem
:
public
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
static
constexpr
auto
TailNum
=
TailNum_
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
_VectorSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
VectorSizeB
=
kPadB
?
_VectorSize
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
VectorSizeC
=
kPadC
?
_VectorSize
/
sizeof
(
CDataType
)
:
1
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
c87aa6c8
...
@@ -9,12 +9,8 @@
...
@@ -9,12 +9,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// UniversalGemm Policy
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
struct
UniversalGemmPipelineAgBgCrPolicy
{
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
static
constexpr
bool
TransposeC
=
true
;
static
constexpr
bool
TransposeC
=
true
;
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorLoadSize
()
{
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
elements_per_thread
=
MNPerBlock
*
KPerBlock
/
BlockSize
;
if
constexpr
(
elements_per_thread
%
(
16
/
sizeof
(
DataType
))
==
0
)
{
return
(
16
/
sizeof
(
DataType
));
}
else
if
constexpr
(
elements_per_thread
%
(
8
/
sizeof
(
DataType
))
==
0
)
{
return
(
8
/
sizeof
(
DataType
));
}
else
if
constexpr
(
elements_per_thread
%
(
4
/
sizeof
(
DataType
))
==
0
&&
sizeof
(
DataType
)
>=
4
)
{
return
(
4
/
sizeof
(
DataType
));
}
else
if
constexpr
(
elements_per_thread
%
(
2
/
sizeof
(
DataType
))
==
0
&&
sizeof
(
DataType
)
>=
2
)
{
return
(
2
/
sizeof
(
DataType
));
}
else
{
return
1
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
auto
DataTypeSize
=
sizeof
(
ADataType
);
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayoutA
>::
value
)
constexpr
auto
MLdsLayer
=
{
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
:
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
);
make_tuple
(
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{},
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
number
<
MPerBlock
/
MLdsLayer
>
{},
make_tuple
(
K0
*
number
<
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
K1
),
number
<
KPack
>
{}),
make_tuple
(
K1
,
number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
number
<
1
>
{});
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
number
<
K0
*
MLdsLayer
>
{})),
a_lds_block_desc_0
,
make_pass_through_transform
(
K1
)),
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
constexpr
auto
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
MLdsLayer
>
{})),
constexpr
auto
a_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
a_lds_block_desc_permuted
,
make_pass_through_transform
(
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MLdsLayer
>
{})),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
,
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{}))),
a_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
return
a_lds_block_desc_m_k
;
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
}
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
else
// ColumnMajor A
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
{
// kfold and mpair dimension is not always required.
return
a_lds_block_desc
;
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I0
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
M0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kM
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
M0
*
sizeof
(
ADataType
)
>
128
)
?
1
:
128
/
(
K1
*
M0
*
sizeof
(
ADataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=kN0
constexpr
auto
mpair
=
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)))
>
M0
?
M0
:
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{},
number
<
mpair
>
{},
K1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
M0
/
mpair
>
{},
number
<
mpair
>
{},
number
<
M1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
auto
DataTypeSize
=
sizeof
(
BDataType
);
constexpr
auto
NLdsLayer
=
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayoutB
>::
value
)
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
make_tuple
(
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{},
?
1
number
<
NPerBlock
/
NLdsLayer
>
{},
:
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
);
number
<
KPack
>
{}),
;
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
NLdsLayer
>
{},
number
<
1
>
{}),
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
number
<
KPack
>
{},
make_tuple
(
K0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
K1
),
number
<
1
>
{});
make_tuple
(
K1
,
number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{})),
number
<
K0
*
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
constexpr
auto
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
NLdsLayer
>
{})),
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
,
b_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_merge_transform_v3_division_mod
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{})),
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{}))),
make_merge_transform_v3_division_mod
(
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
return
b_lds_block_desc
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
N0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kN
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
K1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=kN0
constexpr
auto
npair
=
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
K1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
typename
Problem
::
BDataType
,
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
K0
=
KPerBlock
/
K1
;
{
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
static_assert
(
KPack
%
K3
==
0
);
return
make_static_tile_distribution
(
constexpr
index_t
K2
=
KPack
/
K3
;
tile_distribution_encoding
<
sequence
<
1
>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
1
,
2
>
,
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
sequence
<
0
,
1
>>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
A
DataType
,
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
B
DataType
>
;
typename
Problem
::
B
DataType
,
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
B
Layout
>
;
typename
Problem
::
CDataType
,
Problem
::
Block
GemmShape
::
WarpTile
::
at
(
I0
),
constexpr
index_t
BlockSize
=
Problem
::
k
Block
Size
;
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
TransposeC
>
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
K3
=
total_pixels
/
N1
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
return
make_static_tile_distribution
(
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
tile_distribution_encoding
<
sequence
<
1
>
,
{
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
0
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
c87aa6c8
...
@@ -3,19 +3,23 @@
...
@@ -3,19 +3,23 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
bool
kPad
A
_
,
template
<
bool
kPad
M
_
,
bool
kPad
B
_
,
bool
kPad
N
_
,
bool
kPad
C
_
,
bool
kPad
K
_
,
typename
ALayout_
,
typename
ALayout_
,
typename
BLayout_
,
typename
BLayout_
,
typename
CLayout_
>
typename
CLayout_
>
struct
TileGemmTraits
struct
TileGemmTraits
{
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
c87aa6c8
...
@@ -10,114 +10,134 @@
...
@@ -10,114 +10,134 @@
namespace
ck_tile
{
namespace
ck_tile
{
// fp16
// fp16
using
WarpGemmMfmaF16F16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M
16N16K16
=
using
WarpGemmMfmaF16F16F32M
32N32K8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M
32N32K16
=
using
WarpGemmMfmaF16F16F32M
16N16K16
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
IterateK
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
,
2
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
using
WarpGemmMfmaF16F16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
// bf16
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
// fp8
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
template
<
index_t
swizzle_factor
=
2
>
template
<
index_t
swizzle_factor
=
2
>
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
WGAttrCtlEnum
::
Default_
>
,
2
,
2
,
swizzle_factor
>>
;
swizzle_factor
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -21,9 +21,12 @@ struct WarpGemmAtrributeMfma
...
@@ -21,9 +21,12 @@ struct WarpGemmAtrributeMfma
using
BVecType
=
typename
Impl
::
BVecType
;
using
BVecType
=
typename
Impl
::
BVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -51,10 +54,13 @@ struct WarpGemmAtrributeMfma
...
@@ -51,10 +54,13 @@ struct WarpGemmAtrributeMfma
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
Impl
{}(
c_vec
,
a_vec
,
b_vec
);
Impl
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -81,9 +87,12 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -81,9 +87,12 @@ struct WarpGemmAtrributeMfmaIterateK
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
*
kKIter
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -111,8 +120,11 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -111,8 +120,11 @@ struct WarpGemmAtrributeMfmaIterateK
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -122,10 +134,33 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -122,10 +134,33 @@ struct WarpGemmAtrributeMfmaIterateK
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -164,9 +199,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
...
@@ -164,9 +199,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using
BVecType
=
typename
Impl
::
AVecType
;
using
BVecType
=
typename
Impl
::
AVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -194,11 +232,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
...
@@ -194,11 +232,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
// swap A and B
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -222,9 +263,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
...
@@ -222,9 +263,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
using
BVecType
=
typename
Impl
::
AVecType
;
using
BVecType
=
typename
Impl
::
AVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -255,12 +299,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
...
@@ -255,12 +299,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
sequence
<
2
,
2
>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
// swap A and B
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -287,9 +334,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
...
@@ -287,9 +334,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
*
kKIter
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -316,9 +366,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
...
@@ -316,9 +366,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
sequence
<
2
,
2
>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -328,10 +381,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
...
@@ -328,10 +381,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -372,10 +449,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
...
@@ -372,10 +449,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -429,8 +509,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
...
@@ -429,8 +509,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
#endif
#endif
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -440,10 +523,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
...
@@ -440,10 +523,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -483,10 +589,13 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
...
@@ -483,10 +589,13 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
static
constexpr
index_t
kKPerThread
=
Impl
::
kABKPerLane
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
...
@@ -518,8 +627,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
...
@@ -518,8 +627,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -529,10 +641,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
...
@@ -529,10 +641,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -7,12 +7,68 @@
...
@@ -7,12 +7,68 @@
namespace
ck_tile
{
namespace
ck_tile
{
// TODO: refactor warp-gemm
// currently there is a discrepency for vav/vva if we need transpose C/D
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
// because we swap the A/B pointer in _impl code (but not known this info here)
enum
class
WGAttrCtlEnum
{
Default_
=
0
,
Raw_vvv
=
1
,
// c-vgpr, a-vgpr, b-vgpr
Raw_vaa
=
2
,
// c-vgpr, a-agpr, b-agpr
Raw_vav
=
3
,
// c-vgpr, a-agpr, b-vgpr
Raw_vva
=
4
,
// c-vgpr, a-vgpr, b-agpr
Raw_avv
=
5
,
// c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
// FP16
// FP16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
struct
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
{
using
ADataType
=
fp16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
fp16_t
;
using
ADataType
=
fp16_t
;
using
CDataType
=
float
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
...
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
...
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_32x32x8f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
...
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return
bit_cast
<
CVecType
>
(
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
struct
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
{
using
ADataType
=
fp16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
fp16_t
;
using
ADataType
=
fp16_t
;
using
CDataType
=
float
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
...
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_16x16x16f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return
bit_cast
<
CVecType
>
(
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
// Bf16
// Bf16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
{
using
ADataType
=
bf16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
bf16_t
;
using
ADataType
=
bf16_t
;
using
CDataType
=
float
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
...
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_32x32x8bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
c_vec
,
0
,
0
,
0
,
0
,
0
);
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
{
using
ADataType
=
bf16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
bf16_t
;
using
ADataType
=
bf16_t
;
using
CDataType
=
float
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
...
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_16x16x16bf16_1k"
,
Ctrl
)
{
#if defined(__gfx90a__) || defined(__gfx94__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
c_vec
,
0
,
0
,
0
,
0
,
0
);
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
// FP8
// FP8
template
<
typename
AType_
,
typename
BType_
>
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
{
using
ADataType
=
AType_
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
BType_
;
using
ADataType
=
AType_
;
using
CDataType
=
float
;
using
BDataType
=
BType_
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
...
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
bf8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
bf8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
fp8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
fp8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
bf8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
bf8_t
,
Ctrl_
>
;
// int8
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int32_t
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
CVecType
=
ext_vector_t
<
CDataType
,
16
>
;
static
constexpr
index_t
kM
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKPerLane
=
8
;
static
constexpr
index_t
kCMLane
=
2
;
static
constexpr
index_t
kCNLane
=
32
;
static
constexpr
index_t
kCM0PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_i32_32x32x16_i8"
,
Ctrl
)
else
{
#if defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_i32_32x32x8i8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
CVecType
c_vec
{
0
};
operator
()(
c_vec
,
a_vec
,
b_vec
);
return
c_vec
;
}
};
#undef DISPATCH_MFMA_
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// clang-format off
// fp16
// fp16
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
// clang-format on
// clang-format on
}
// namespace impl
}
// namespace impl
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
c87aa6c8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -14,6 +14,11 @@ struct WarpGemmImpl
...
@@ -14,6 +14,11 @@ struct WarpGemmImpl
static
constexpr
index_t
kM
=
WarpGemmAttribute
::
kM
;
static
constexpr
index_t
kM
=
WarpGemmAttribute
::
kM
;
static
constexpr
index_t
kN
=
WarpGemmAttribute
::
kN
;
static
constexpr
index_t
kN
=
WarpGemmAttribute
::
kN
;
static
constexpr
index_t
kK
=
WarpGemmAttribute
::
kK
;
static
constexpr
index_t
kK
=
WarpGemmAttribute
::
kK
;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static
constexpr
index_t
kKPerThread
=
WarpGemmAttribute
::
kKPerThread
;
using
ADataType
=
typename
WarpGemmAttribute
::
ADataType
;
using
ADataType
=
typename
WarpGemmAttribute
::
ADataType
;
using
BDataType
=
typename
WarpGemmAttribute
::
BDataType
;
using
BDataType
=
typename
WarpGemmAttribute
::
BDataType
;
...
@@ -31,11 +36,21 @@ struct WarpGemmImpl
...
@@ -31,11 +36,21 @@ struct WarpGemmImpl
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
CK_TILE_DEVICE
void
operator
()(
CWarpTensor
&
c
,
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
get_num_of_access
()
{
{
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
return
WarpGemmAttribute_
::
get_num_of_access
();
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
}
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
{
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
CWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
AWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BWarpTensor
>
);
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
...
@@ -44,18 +59,49 @@ struct WarpGemmImpl
...
@@ -44,18 +59,49 @@ struct WarpGemmImpl
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
);
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
}
CK_TILE_DEVICE
auto
operator
()(
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
index_t
i_subk
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
number
<
i_subk
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
CWarpTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
const
auto
a_vec
=
a
.
get_thread_buffer
().
template
get_as
<
AVec
>()[
I0
];
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
const
auto
b_vec
=
b
.
get_thread_buffer
().
template
get_as
<
BVec
>()[
I0
];
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
number
<
i_subk
>
{},
bool_constant
<
post_nop_
>
{});
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
template
<
typename
ATensor
,
typename
BTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
using
CTensor
=
CWarpTensor
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
AWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BWarpTensor
>
);
CTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
c87aa6c8
...
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{}
);
// compute inv-std
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
auto
inv_std
=
tile_elementwise_in
(
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
c87aa6c8
...
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{}
);
// compute inv-std
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
));
if
(
kFastFDiv
&&
std
::
is_same_v
<
ComputeDataType
,
float
>
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
*
__builtin_amdgcn_rcpf
(
sqrt
(
v_
+
epsilon
));
}
else
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
sqrt
(
v_
+
epsilon
);
}
},
},
var
);
var
);
if
constexpr
(
kSaveMean
)
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
if
constexpr
(
kSaveInvStd
)
...
...
include/ck_tile/ops/smoothquant.hpp
View file @
c87aa6c8
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment