gaoqiong / composable_kernel_ROCM · Commits

Commit c881136b, authored Jan 01, 2025 by Po Yen Chen

Merge branch 'develop' into ck_tile/support-vllm-kcache-layout

Parents: c5e8e14f, 4e076909

Changes: 75 files changed. Showing 20 changed files with 986 additions and 224 deletions.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp  +226 -0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  +29 -7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp  +14 -41
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp  +0 -2
include/ck_tile/ops/fused_moe.hpp  +1 -1
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp  +210 -37
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp  +9 -4
include/ck_tile/ops/gemm.hpp  +1 -1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp  +29 -15
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp  +29 -15
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp  +25 -7
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  +120 -44
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp  +2 -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp  +2 -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp  +2 -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp  +8 -6
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp  +2 -0
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp  +2 -0
include/ck_tile/ops/gemm/warp/warp_gemm.hpp  +16 -0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp  +259 -44
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"

namespace ck_tile {

// This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                          /* AsyncCopyK = */ false,
                                          /* AsyncCopyV = */ false,
                                          /* NumPrefetchK = */ 1,
                                          /* NumPrefetchV = */ 1>
{
    using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                                           /* AsyncCopyK = */ false,
                                                           /* AsyncCopyV = */ false,
                                                           /* NumPrefetchK = */ 1,
                                                           /* NumPrefetchV = */ 1>;

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);

        // this should align with MakeQDramTileDistribution()
        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
        static_assert(0 < ElemPerThread);

        return min(ElemPerThread, MaxVectorSize);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
    {
        using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
        return static_cast<index_t>(16 / sizeof(OaccDataType));
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);

        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
        static_assert(0 < ElemPerThread);
        constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);

        constexpr index_t KPerThread     = kMaxVecLoad;
        constexpr index_t KThreads       = kKPerBlock / KPerThread;
        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                             sequence<KThreads, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    template <typename Problem, typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
    {
        return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
    {
        // TODO: this is for 3d layout
        using QDataType = remove_cvref_t<typename Problem::QDataType>;
        return static_cast<index_t>(16 / sizeof(QDataType));
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
        static_assert(0 < ElemPerThread);
        constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());

        constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
            make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        constexpr auto q_lds_block_desc = transform_tensor_descriptor(
            q_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return q_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
    {
        using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
        return static_cast<index_t>(16 / sizeof(SDataType));
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
    {
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kNPack     = GetSmemNPackS<Problem>();

        constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
            make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
            number<kNPack>{},
            number<1>{});

        constexpr auto s_lds_block_desc = transform_tensor_descriptor(
            s_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return s_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
    {
        using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;

        constexpr auto config   = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();
        static_assert(MWarp == 1, "Check failed!");

        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
        constexpr index_t kTileK     = Problem::BlockFmhaShape::kN0;

        // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
        constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
        constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
        constexpr index_t K1 = kKPerBlock / (K2 * K3);
        constexpr index_t K0 = kTileK / kKPerBlock;
        constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t M1 = MWarp;
        constexpr index_t M0 = kMPerBlock / (M2 * M1);

        constexpr auto s2_block_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
                                       tuple<sequence<1, 0>, sequence<2, 1>>,
                                       tuple<sequence<1, 0>, sequence<2, 2>>,
                                       sequence<1, 2, 2, 2>,
                                       sequence<0, 0, 1, 3>>{};

        constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);

        return s2_block_dstr;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
    {
        return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
               sizeof(typename Problem::QDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
    {
        return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
               sizeof(typename Problem::KDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
    {
        return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
               sizeof(typename Problem::VDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
    {
        return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
               sizeof(typename Problem::SaccDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
               max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
    }
};

} // namespace ck_tile
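To make the alignment arithmetic in GetAlignmentQ() above concrete, here is a minimal standalone sketch (not part of the commit) using assumed example values: a 256-thread workgroup, a 64x64 Q tile, and fp16 Q data.

#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int kBlockSize    = 256;    // assumed Problem::kBlockSize
    constexpr int kMPerBlock    = 64;     // assumed BlockFmhaShape::kM0
    constexpr int kKPerBlock    = 64;     // assumed BlockFmhaShape::kSubQKHeaddim
    constexpr int MaxVectorSize = 16 / 2; // 16 bytes / sizeof(fp16) = 8 elements

    // each thread owns an equal slice of the Q tile
    constexpr int ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; // 16
    static_assert(0 < ElemPerThread);

    // the vector load width is capped at one 16-byte (dwordx4) transaction
    constexpr int AlignmentQ = std::min(ElemPerThread, MaxVectorSize); // 8
    std::printf("AlignmentQ = %d elements\n", AlignmentQ);
    return 0;
}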
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

@@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };
 
+// extract tile size attributes to remove dependency on traits
+template <typename OaccDataType_, ck_tile::index_t kN1_>
+struct BlockFmhaSplitKVCombinePipelineTileSizes
+{
+    static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
+
+    static constexpr index_t kN1      = kN1_;
+    static constexpr index_t NThreads = kN1 / MaxVectorSize;
+    static constexpr index_t kM0      = get_warp_size() / NThreads; // MThreadPerWarp
+};
+
 template <typename LSEDataType_,
           typename OaccDataType_,
           typename ODataType_,
           index_t HeadDimV_,
-          index_t kM0_,
-          index_t kN1_,
           bool kIsGroupMode_,
+          ck_tile::index_t kN1_,
           typename Traits_>
 struct BlockFmhaSplitKVCombinePipelineProblem
+    : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
 {
+    using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>;
+
     using LSEDataType  = remove_cvref_t<LSEDataType_>;
     using OaccDataType = remove_cvref_t<OaccDataType_>;
     using ODataType    = remove_cvref_t<ODataType_>;
     using Traits       = remove_cvref_t<Traits_>;
 
-    static constexpr index_t kNumWarps  = kM0_ / (get_warp_size() / 4);
-    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
+    static_assert(std::is_same_v<LSEDataType, OaccDataType>);
 
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
     static constexpr index_t kHeadDimV = HeadDimV_;
-    static constexpr index_t kM0       = kM0_;
-    static constexpr bool kIsGroupMode = kIsGroupMode_;
-    static constexpr index_t kN1       = kN1_;
+
+    using BaseType::kM0;
+    using BaseType::kN1;
+
+    static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
 
     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...
@@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
     static constexpr index_t kMaxSplits     = Traits::kMaxSplits;
+
+    static_assert(8 <= kMaxSplits);
+
+    static constexpr index_t kNumWarps  = 4; // always use 4 warps for each workgroup
+    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
+
+    static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
+                  (kM0 * kMaxSplits) % get_warp_size() == 0);
 };
 
 template <typename QDataType_,
...
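For reference, a minimal standalone sketch (not part of the commit) of the tile-size derivation in the new BlockFmhaSplitKVCombinePipelineTileSizes, with assumed example values: float accumulators, kN1 = 32, and a 64-lane wave.

#include <cstdio>

int main()
{
    constexpr int warp_size     = 64;     // assumed get_warp_size()
    constexpr int kN1           = 32;     // assumed kN1_
    constexpr int MaxVectorSize = 16 / 4; // 16 bytes / sizeof(float) = 4 elements

    constexpr int NThreads = kN1 / MaxVectorSize;  // 8 lanes cover one N row
    constexpr int kM0      = warp_size / NThreads; // 8 rows per warp (MThreadPerWarp)

    std::printf("NThreads=%d kM0=%d\n", NThreads, kM0);
    return 0;
}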
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp

@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
         using BlockGemm       = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
         constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG              = remove_cvref_t<decltype(config.template at<0>())>;
-        return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
+        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
     }
 
     template <typename Problem, typename BlockGemm>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
     {
-        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp = config.template at<1>();
-
-        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
-
-        constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K0 = kKPerBlock / (K1 * K2);
-        constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
-        constexpr index_t M1 = MWarp;
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
-
-        if constexpr(1 < Problem::kNumGemm0Warps)
-        {
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<1>, sequence<2, 1>>,
-                                           tuple<sequence<1>, sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
-        else
-        {
-            static_assert(MWarp == 1);
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<2, 1>>,
-                                           tuple<sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
+        return BlockGemm::template MakeABlockTileDistribution<
+            Problem::BlockFmhaShape::kM0,
+            Problem::BlockFmhaShape::kSubQKHeaddim>();
     }
 
     template <typename Problem>
...
@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr auto warp_gemm = []() {
             constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-            static_assert(WarpGemmM == 16 || WarpGemmM == 32);
+            static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
 
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
...
@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
...
@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
...
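A minimal standalone sketch (not part of the commit) of why GetAlignmentQ() now caps the return value: for fp16, 16 bytes hold 8 elements, so a warp-gemm covering 16 K-elements per lane must still be loaded in 8-element chunks. The KPerLane value below is an assumed example standing in for WG::kK / Impl::kABKLane.

#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int MaxVectorSize = 16 / 2; // 16 bytes / sizeof(fp16) = 8 elements
    constexpr int KPerLane      = 16;     // assumed WG::kK / Impl::kABKLane

    constexpr int old_alignment = KPerLane;                          // 16: wider than one load
    constexpr int new_alignment = std::min(MaxVectorSize, KPerLane); // 8: one dwordx4 load

    std::printf("old=%d new=%d\n", old_alignment, new_alignment);
    return 0;
}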
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp

@@ -43,8 +43,6 @@ struct TileFmhaShape
     static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
 
-    static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
-
     static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
     static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
     static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
...
include/ck_tile/ops/fused_moe.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp

@@ -130,7 +130,8 @@ struct MoeSortingKernel
     CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
     {
         const auto blocks = BlockSize(h);
-        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
+        // usually num_experts is power of 2, we pad 1 dword here for the row-size
+        return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
     }
 
     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
...
@@ -154,6 +155,75 @@ struct MoeSortingKernel
         return k;
     }
 
+    // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
+    template <typename data_t, int wave_size>
+    __device__ inline void wave_cumsum(data_t& thread_data) const
+    {
+        // wave_size must be power of 2
+        constexpr int row_mask    = 0xf;
+        constexpr int bank_mask   = 0xf;
+        constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
+
+        auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
+
+        if constexpr(wave_size > 1)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x111,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:1
+        }
+        if constexpr(wave_size > 2)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x112,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:2
+        }
+        if constexpr(wave_size > 4)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x114,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:4
+        }
+        if constexpr(wave_size > 8)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x118,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:8
+        }
+        if constexpr(wave_size > 16)
+        {
+            // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+        if constexpr(wave_size > 32)
+        {
+            // lane-id 48...63->31
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+    }
+
     CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
     {
         return row * total_col + col;
...
@@ -187,48 +257,124 @@ struct MoeSortingKernel
         index_t* shared_mem = reinterpret_cast<index_t*>(smem);
 
         index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
-        index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
+        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
 
         for(int i = 0; i < num_experts; ++i)
         {
-            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
+            tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
         }
 
 #pragma unroll Problem_::InternalLoadUnroll
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
-            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
+            ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
         }
 
         __syncthreads();
 
+#if 1
         if(tid < num_experts)
         {
-            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
-            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
+            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+            index_t local_c[8];
+            index_t prev_c = 0;
+            // TODO: manually unroll. pragma unroll does not work well when we have dependency
+            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
             {
-                tokens_cnts[calc_index(num_experts, i, tid)] +=
-                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
+                local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
+                local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
+                local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
+                local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
+                local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
+                local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
+                local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
+                local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
+
+                local_c[0] += prev_c;
+                local_c[1] += local_c[0];
+                local_c[2] += local_c[1];
+                local_c[3] += local_c[2];
+                local_c[4] += local_c[3];
+                local_c[5] += local_c[4];
+                local_c[6] += local_c[5];
+                local_c[7] += local_c[6];
+                prev_c = local_c[7];
+
+                tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
+                tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
+                tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
+                tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
+                tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
+                tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
+                tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
+                tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
             }
         }
+#else
+        // __syncthreads();
+        // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
+        if(tid < num_experts)
+            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+        for(int i = 0; i < num_experts; i += 8)
+        {
+            index_t local_c[8];
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
+            }
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                wave_cumsum<int, 64>(local_c[j]);
+            }
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
+            }
+        }
+#endif
 
         __syncthreads();
 
-        if(tid == 0)
+        if constexpr(Problem::ExpertTile == 0)
         {
-            cumsum[0] = 0;
-            for(int i = 1; i <= num_experts; ++i)
+            if(tid == 0)
             {
-                auto current_units = [&]() {
-                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
-                                 unit_size_mdiv.divisor - 1;
-                    index_t y_ = unit_size_mdiv.div(x_);
-                    return max(y_, 1) * unit_size_mdiv.divisor;
-                }();
-                cumsum[i] = cumsum[i - 1] + current_units;
+                cumsum[0] = 0;
+                for(int i = 1; i <= num_experts; ++i)
+                {
+                    auto current_units = [&]() {
+                        index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
+                                     unit_size_mdiv.divisor - 1;
+                        index_t y_ = unit_size_mdiv.div(x_);
+                        return max(y_, 1) * unit_size_mdiv.divisor;
+                    }();
+                    cumsum[i] = cumsum[i - 1] + current_units;
+                }
+                *p_total_tokens_post_pad = cumsum[num_experts];
             }
-            *p_total_tokens_post_pad = cumsum[num_experts];
         }
+        else
+        {
+            // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
+            // for simplicity, not check experts here.
+            int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+            int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
+            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
+            int local_cumsum = padded_tokens_per_expert;
+            wave_cumsum<int, 64>(local_cumsum);
+            if(tid == (num_experts - 1))
+            {
+                cumsum[0] = 0;
+                *p_total_tokens_post_pad = local_cumsum;
+            }
+            if(tid < num_experts)
+            {
+                cumsum[tid + 1] = local_cumsum;
+            }
+        }
 
         __syncthreads();
 
         if(tid < num_experts)
         {
-            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
+            int e_start = cumsum[tid];
+            int e_end   = cumsum[tid + 1];
+            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
             {
                 p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
             }
...
@@ -238,8 +384,8 @@ struct MoeSortingKernel
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
             index_t expert_id = topk_id[i];
-            index_t rank_post_pad =
-                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
+            index_t local_cnt     = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
+            index_t rank_post_pad = local_cnt + cumsum[expert_id];
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
             uint32_t curr_token_id, curr_topk_id;
             topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
...
@@ -247,27 +393,54 @@ struct MoeSortingKernel
 #else
             p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
 #endif
             p_sorted_weights[rank_post_pad] = weights[i];
-            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
+            tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
         }
 
-        const index_t prefill_token = topk_mdiv.div(numel);
-        if(tid < num_experts)
+        if constexpr(Problem::ExpertTile == 0)
         {
-            index_t expert_offset =
-                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
-            while(expert_offset < cumsum[tid + 1])
+            const index_t prefill_token = topk_mdiv.div(numel);
+            if(tid < num_experts)
             {
+                index_t expert_offset =
+                    cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+                index_t expert_end = cumsum[tid + 1];
+                while(expert_offset < expert_end)
+                {
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
-                p_sorted_token_ids[expert_offset] =
-                    MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+                    p_sorted_token_ids[expert_offset] =
+                        MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
 #else
-                p_sorted_token_ids[expert_offset] = prefill_token;
+                    p_sorted_token_ids[expert_offset] = prefill_token;
 #endif
-                p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
-                expert_offset++;
+                    p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                    expert_offset++;
+                }
             }
         }
+        else
+        {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            // TODO: only support expert-tile like 8, 16, 32
+            static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
+            {
+                index_t eid = tid / experts_per_wave;
+                index_t expert_offset =
+                    cumsum[eid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
+                    tid % experts_per_wave;
+                index_t expert_end = cumsum[eid + 1];
+                if(eid < num_experts)
+                {
+                    while(expert_offset < expert_end)
+                    {
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                        p_sorted_token_ids[expert_offset] =
+                            MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+#else
+                        p_sorted_token_ids[expert_offset] = prefill_token;
+#endif
+                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                        expert_offset += experts_per_wave;
+                    }
+                }
+            }
+        }
     }
 
     CK_TILE_DEVICE void operator()(Kargs kargs) const
...
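The added wave_cumsum() computes an inclusive prefix sum across the lanes of a wave using DPP row shifts and ds_bpermute. A CPU reference sketch (not part of the commit) of the same semantics, with a plain array standing in for one 64-lane wave:

#include <cstdio>

int main()
{
    constexpr int wave_size = 64;
    int lane_val[wave_size];
    for(int lane = 0; lane < wave_size; ++lane)
        lane_val[lane] = 1; // assumed per-lane input, e.g. padded per-expert unit counts

    // [a, b, c, d, ...] -> [a, a+b, a+b+c, a+b+c+d, ...]
    for(int lane = 1; lane < wave_size; ++lane)
        lane_val[lane] += lane_val[lane - 1];

    // the last lane ends up holding the whole-wave total
    std::printf("lane 63 holds the wave total: %d\n", lane_val[wave_size - 1]);
    return 0;
}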
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp

@@ -9,15 +9,20 @@
 namespace ck_tile {
 
-template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
+template <typename IndexType_,
+          typename WeightType_,
+          index_t InternalLoadUnroll_,
+          index_t ExpertTile_ = 0>
 struct MoeSortingProblem
 {
     // TODO: this kernel only support warp per row
     using WeightType = remove_cvref_t<WeightType_>;
     using IndexType  = remove_cvref_t<IndexType_>;
 
     static constexpr index_t WarpSize           = get_warp_size();
     static constexpr index_t WarpsPerBlock      = 1;
     static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
+    // TODO: need better design(like tile size)
+    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
 };
 } // namespace ck_tile
include/ck_tile/ops/gemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp

@@ -65,14 +65,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
         const index_t iNWarp = 0;
 
-        constexpr auto a_block_outer_dstr_encoding =
-            tile_distribution_encoding<sequence<NWarp>,
-                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-                                       tuple<sequence<1, 0>>,
-                                       tuple<sequence<1, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 0>>{};
-
         constexpr auto c_block_outer_dstr_encoding =
             tile_distribution_encoding<sequence<>,
                                        tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
...
@@ -81,19 +73,14 @@ struct BlockGemmARegBSmemCRegOneWarpV1
                                        sequence<1, 2>,
                                        sequence<0, 0>>{};
 
-        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
         constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
             c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
 
-        constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
-
         // constrcut from A-block-tensor from A-Block-tensor-tmp
         // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
         // distribution
         auto a_block_tensor =
-            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
+            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
+                MakeABlockTileDistribution());
 
         a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
...
@@ -187,6 +174,33 @@ struct BlockGemmARegBSmemCRegOneWarpV1
         });
     }
 
+    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
+    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
+    {
+        constexpr auto config   = Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+        constexpr auto a_block_outer_dstr_encoding =
+            tile_distribution_encoding<sequence<NWarp>,
+                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                       tuple<sequence<1, 0>>,
+                                       tuple<sequence<1, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+
+        return make_static_tile_distribution(a_block_dstr_encode);
+    }
+
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
         constexpr index_t MPerBlock = BlockGemmShape::kM;
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp

@@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2
         const index_t iNWarp = get_warp_id() % NWarp;
 
-        constexpr auto a_block_outer_dstr_encoding =
-            tile_distribution_encoding<sequence<NWarp>,
-                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-                                       tuple<sequence<1, 0>>,
-                                       tuple<sequence<1, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 0>>{};
-
         constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
             sequence<>,
             tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
...
@@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2
             sequence<1, 2>,
             sequence<0, 0>>{};
 
-        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
         constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
             c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
 
-        constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
-
         // constrcut from A-block-tensor from A-Block-tensor-tmp
         // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
         auto a_block_tensor =
-            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
+            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
+                MakeABlockTileDistribution());
 
         a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
...
@@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2
         });
     }
 
+    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
+    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
+    {
+        constexpr auto config   = Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+        constexpr auto a_block_outer_dstr_encoding =
+            tile_distribution_encoding<sequence<NWarp>,
+                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                       tuple<sequence<1, 0>>,
+                                       tuple<sequence<1, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+
+        return make_static_tile_distribution(a_block_dstr_encode);
+    }
+
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
         constexpr index_t MPerBlock = BlockGemmShape::kM;
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp

@@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     using KernelArgs = BatchedGemmKernelArgs;
 
-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
+    __host__ static constexpr auto
+    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
     {
-        return TilePartitioner::GridSize(M, N, batch_count);
+        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
     }
 
     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
...
@@ -85,7 +86,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
                                      hostArgs.K,
                                      hostArgs.stride_A,
                                      hostArgs.stride_B,
-                                     hostArgs.stride_C},
+                                     hostArgs.stride_C,
+                                     hostArgs.k_batch},
                                     hostArgs.batch_stride_A,
                                     hostArgs.batch_stride_B,
                                     hostArgs.batch_stride_C,
...
@@ -100,22 +102,38 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
+        const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
 
         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
+                                 splitk_batch_offset.a_k_split_offset;
 
         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
+                                 splitk_batch_offset.b_k_split_offset;
 
         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
         const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
 
-        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.KBatch == 1)
+        {
+            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            this->template RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
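A standalone sketch (not part of the commit) of how the batched kernel now folds the split-K slice into blockIdx.z: with KBatch k-slices per GEMM, grid.z = KBatch * batch_count, and each z decomposes back into (i_batch, i_k). The values below are assumed examples.

#include <cstdio>

int main()
{
    const int KBatch      = 4; // assumed kargs.KBatch
    const int batch_count = 3; // assumed number of GEMMs in the batch

    for(int z = 0; z < KBatch * batch_count; ++z)
    {
        const int i_batch = z / KBatch;           // which GEMM in the batch
        const int i_k     = z - i_batch * KBatch; // which K-slice of that GEMM
        std::printf("z=%2d -> batch %d, k-slice %d\n", z, i_batch, i_k);
    }
    return 0;
}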
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
c881136b
...
@@ -93,6 +93,7 @@ struct GemmKernel
...
@@ -93,6 +93,7 @@ struct GemmKernel
index_t
stride_A
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_B
;
index_t
stride_C
;
index_t
stride_C
;
index_t
KBatch
;
};
};
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
...
@@ -105,28 +106,72 @@ struct GemmKernel
...
@@ -105,28 +106,72 @@ struct GemmKernel
hostArgs
.
K
,
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
};
hostArgs
.
stride_C
,
hostArgs
.
k_batch
};
}
}
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
// const void* b_ptr,
// void* c_ptr,
// index_t M,
// index_t N,
// index_t K,
// index_t stride_A,
// index_t stride_B,
// index_t stride_C)
// {
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
// }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
}
struct
SplitKBatchOffset
{
__device__
SplitKBatchOffset
(
const
GemmKernelArgs
&
kargs
,
const
std
::
size_t
k_id
=
blockIdx
.
z
)
{
constexpr
auto
K1
=
TilePartitioner
::
BlockGemmShape
::
WarpTile
::
at
(
number
<
2
>
{});
const
index_t
K_t
=
kargs
.
KBatch
*
K1
;
const
index_t
KRead
=
(
kargs
.
K
+
K_t
-
1
)
/
K_t
*
K1
;
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
a_k_split_offset
=
k_id
*
KRead
;
}
else
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
k_id
*
KRead
*
kargs
.
stride_A
;
}
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
k_id
*
KRead
*
kargs
.
stride_B
;
}
else
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
b_k_split_offset
=
k_id
*
KRead
;
}
if
(
k_id
<
static_cast
<
uint32_t
>
(
kargs
.
KBatch
-
1
))
{
splitted_k
=
KRead
;
}
else
{
splitted_k
=
kargs
.
K
-
KRead
*
(
kargs
.
KBatch
-
1
);
}
}
index_t
a_k_split_offset
;
index_t
b_k_split_offset
;
index_t
splitted_k
;
};
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
{
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
(
!
((
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
)
||
!
(
std
::
is_same_v
<
CDataType
,
fp16_t
>
||
std
::
is_same_v
<
CDataType
,
bf16_t
>
)))
{
if
(
kargs
.
KBatch
!=
1
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
...
@@ -198,17 +243,19 @@ struct GemmKernel
...
@@ -198,17 +243,19 @@ struct GemmKernel
return
true
;
return
true
;
}
}
CK_TILE_DEVICE
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
const
BDataType
*
b_ptr
,
CK_TILE_DEVICE
static
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
CDataType
*
c_ptr
,
const
BDataType
*
b_ptr
,
const
GemmKernelArgs
&
kargs
)
const
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
)
{
{
const
auto
&
a_tensor_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
a_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_A
,
1
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -217,7 +264,7 @@ struct GemmKernel
...
@@ -217,7 +264,7 @@ struct GemmKernel
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
a_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
1
,
kargs
.
stride_A
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -229,7 +276,7 @@ struct GemmKernel
...
@@ -229,7 +276,7 @@ struct GemmKernel
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
b_ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
1
,
kargs
.
stride_B
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -238,7 +285,7 @@ struct GemmKernel
...
@@ -238,7 +285,7 @@ struct GemmKernel
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
b_ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_B
,
1
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -248,7 +295,7 @@ struct GemmKernel
...
@@ -248,7 +295,7 @@ struct GemmKernel
const
auto
&
c_tensor_view
=
[
&
]()
{
const
auto
&
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
,
DstInMemOp
>
(
c_ptr
,
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
make_tuple
(
kargs
.
stride_C
,
1
),
...
@@ -257,7 +304,7 @@ struct GemmKernel
...
@@ -257,7 +304,7 @@ struct GemmKernel
}
}
else
else
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
,
DstInMemOp
>
(
c_ptr
,
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
make_tuple
(
1
,
kargs
.
stride_C
),
...
@@ -270,7 +317,7 @@ struct GemmKernel
...
@@ -270,7 +317,7 @@ struct GemmKernel
}
}
template
<
typename
TensorView
>
template
<
typename
TensorView
>
CK_TILE_DEVICE
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
const
CK_TILE_DEVICE
static
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
{
{
const
auto
&
a_pad_view
=
[
&
]()
{
const
auto
&
a_pad_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
views
.
at
(
I0
);
const
auto
&
a_tensor_view
=
views
.
at
(
I0
);
...
@@ -330,8 +377,8 @@ struct GemmKernel
...
@@ -330,8 +377,8 @@ struct GemmKernel
}
}
template
<
typename
PadView
>
template
<
typename
PadView
>
CK_TILE_DEVICE
auto
CK_TILE_DEVICE
static
auto
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
{
{
const
auto
&
a_pad_view
=
views
.
at
(
I0
);
const
auto
&
a_pad_view
=
views
.
at
(
I0
);
const
auto
&
a_block_window
=
make_tile_window
(
const
auto
&
a_block_window
=
make_tile_window
(
...
@@ -363,23 +410,27 @@ struct GemmKernel
      * @param kargs GEMM kernel arguments
      * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
      * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
      */
-    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
-                                const BDataType* b_ptr,
-                                CDataType* c_ptr,
-                                const GemmKernelArgs& kargs,
-                                const index_t block_idx_m,
-                                const index_t block_idx_n) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       CDataType* c_ptr,
+                                       void* smem_ptr,
+                                       const GemmKernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
...
@@ -389,18 +440,43 @@ struct GemmKernel
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
-        EpiloguePipeline{}(c_block_window, c_block_tile);
+
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
+                     (GemmPipeline::VectorSizeC % 2 == 0 &&
+                      std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                      is_output_c_reg_transposed))
+        {
+            EpiloguePipeline{}
+                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                    c_block_window, c_block_tile);
+        }
     }

     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
+        const SplitKBatchOffset splitk_batch_offset(kargs);
+
         // options
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
+        const ADataType* a_ptr =
+            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
+        const BDataType* b_ptr =
+            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

-        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.KBatch == 1)
+        {
+            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
};
...
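Why the kargs.KBatch == 1 branch can keep the default set operation while KBatch > 1 must use atomic_add: each split computes only a partial sum over its K slice, so concurrent workgroups must accumulate into C rather than overwrite it. A minimal host-side sketch of the equivalence (illustrative only, not kernel code):

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    const int K = 10, KBatch = 3;
    std::vector<int> a(K), b(K);
    for(int k = 0; k < K; ++k)
    {
        a[k] = k;
        b[k] = 2 * k;
    }

    // reference: full-K dot product (the KBatch == 1 path)
    int full = 0;
    for(int k = 0; k < K; ++k)
        full += a[k] * b[k];

    // split-K: each split contributes a partial sum; the += plays the role
    // of the atomic_add performed by concurrent workgroups on C.
    int c = 0;
    const int k_per_split = (K + KBatch - 1) / KBatch;
    for(int split = 0; split < KBatch; ++split)
    {
        int partial = 0;
        const int k_end = std::min(K, (split + 1) * k_per_split);
        for(int k = split * k_per_split; k < k_end; ++k)
            partial += a[k] * b[k];
        c += partial;
    }
    assert(c == full);
    return 0;
}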
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         return Policy::template GetSmemSize<Problem>();
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
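The same one-line IsTransposeC() forwarder recurs in each pipeline hunk below; it lets the kernel's epilogue guard (see is_output_c_reg_transposed above) query the policy's compile-time flag through a uniform pipeline interface. A generic sketch of the pattern, with made-up MyPolicy/MyPipeline names standing in for the policy/pipeline pairs in this commit:

// Generic sketch of the compile-time flag-forwarding pattern.
struct MyPolicy
{
    static constexpr bool TransposeC = false;
    static constexpr auto IsTransposeC() { return TransposeC; }
};

template <typename Policy>
struct MyPipeline
{
    // the kernel only ever talks to the pipeline, never the policy directly
    static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
};

static_assert(!MyPipeline<MyPolicy>::IsTransposeC(), "flag is visible at compile time");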
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
...
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         return Policy::template GetSmemSize<Problem>();
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
         return Policy::template GetSmemSize<Problem>();
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...
@@ -13,6 +13,8 @@ namespace ck_tile {
 struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
 {
+    static constexpr bool TransposeC = false;
+
 #if 0
     // 2d
     template <typename Problem>
...
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
+        constexpr index_t smem_size = smem_size_a + smem_size_b;

         return smem_size;
     }
...
@@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         }
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
-        constexpr bool TransposeC = false;
         constexpr auto I0 = number<0>{};
         constexpr auto I1 = number<1>{};
         constexpr auto I2 = number<2>{};

         using AccDataType = float;
         using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
...
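A likely motivation for the GetSmemSize change above: making the accumulated smem_size a constexpr local guarantees the sum is folded at compile time, keeping the function usable as an array bound for the kernel's __shared__ buffer. A minimal sketch with invented sizes:

// Sketch only; PolicySketch and the byte counts are invented.
using index_t = int;

struct PolicySketch
{
    static constexpr index_t GetSmemSizeA() { return 4096; }
    static constexpr index_t GetSmemSizeB() { return 8192; }

    static constexpr index_t GetSmemSize()
    {
        // constexpr here forces compile-time evaluation and surfaces any
        // non-constant contribution as a build error rather than at runtime
        constexpr index_t smem_size = GetSmemSizeA() + GetSmemSizeB();
        return smem_size;
    }
};

static_assert(PolicySketch::GetSmemSize() == 12288);
// device code may then write: __shared__ char smem[PolicySketch::GetSmemSize()];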
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
         Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }

+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
...
@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
         WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
         2>>;

+using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
+using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
 // bf16
 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
...
@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
         WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
         2>>;

+using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
+using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
 // fp8
 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
...
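Reading the new aliases: WarpGemmAtrributeMfmaIterateK composes a larger-K warp GEMM out of repeated issues of a smaller base MFMA, so the K16 variants above are four issues of the K4 base ops (M4N64K4 / M64N4K4). A scalar stand-in for the iterate-K idea, with hypothetical BaseK4/IterateK types rather than the real warp GEMM machinery:

#include <array>

// Scalar stand-in for a "K4" MFMA: accumulate four products into c.
struct BaseK4
{
    static void run(float& c, const std::array<float, 4>& a, const std::array<float, 4>& b)
    {
        for(int k = 0; k < 4; ++k)
            c += a[k] * b[k];
    }
};

// Iterate-K wrapper: a K(4 * kKIter) op built from kKIter base-op issues.
template <typename BaseOp, int kKIter>
struct IterateK
{
    static void run(float& c,
                    const std::array<std::array<float, 4>, kKIter>& a,
                    const std::array<std::array<float, 4>, kKIter>& b)
    {
        for(int i = 0; i < kKIter; ++i)
            BaseOp::run(c, a[i], b[i]); // one base-op issue per K-slice
    }
};

// IterateK<BaseK4, 4> plays the role of the new K16 aliases.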
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
...
@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...
@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

-    using AWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using BWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using CWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
-              sequence<Impl::kCNLane>>,
-        tuple<sequence<1, 2>>,
-        tuple<sequence<1, 0>>,
-        sequence<1, 1>,
-        sequence<0, 2>>;
+    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
+                  "Multi-block on both M & N directions is not supported");
+
+    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // each M blocks share the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kBNBlock>,
+                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // single block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // single block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // each N blocks share the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kAMBlock>,
+                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
+                      sequence<Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
+                      sequence<Impl::kBNBlock * Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
+                      sequence<Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+    }
+
+    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
+    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
+    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

     // c_vec += a_vec * b_vec
     template <bool post_nop_ = false>
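The refactoring above exists because a type alias cannot branch on if constexpr directly: the encoding is instead built by a constexpr function whose return type differs per branch, and the alias recovers it with decltype. A self-contained sketch of that selection pattern (names invented):

#include <type_traits>

// Invented names; the point is the decltype-of-constexpr-function pattern.
template <int M, int N>
struct Selector
{
    static constexpr auto pick()
    {
        if constexpr(M == 1 && N == 1)
            return int{};    // stands in for the single-block encoding type
        else
            return double{}; // stands in for a multi-block encoding type
    }

    using type = decltype(pick()); // alias chosen per instantiation
};

static_assert(std::is_same_v<Selector<1, 1>::type, int>);
static_assert(std::is_same_v<Selector<1, 2>::type, double>);

Only one return statement is instantiated per specialization, so each branch may name a different type without any common base.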
...
@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...
@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...
@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

-    using AWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using BWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using CWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kCNLane>,
-              sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<1, 0>>,
-        sequence<2, 2>,
-        sequence<0, 2>>;
+    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
+                  "Multi-block on both M & N directions is not supported");
+
+    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // single block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // each N blocks share the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kAMBlock>,
+                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // each M blocks share the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kBNBlock>,
+                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // single block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCNLane>,
+                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
+                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCNLane>,
+                      sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+    }
+
+    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
+    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
+    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

     template <bool post_nop_ = false>
     // c_vec += a_vec * b_vec
...
@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
...
@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
...