gaoqiong / composable_kernel_ROCM · Commits

Commit 7572a691
Authored Feb 15, 2025 by coderfeli

    merge develop

Parents: 7796fc73, 6b6fcd37
Changes: 465 files
Showing 20 changed files with 1281 additions and 490 deletions (+1281 / -490)
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (+71, -7)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp (+30, -9)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp (+0, -48)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (+158, -54)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp (+0, -54)
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp (+0, -105)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp (+2, -0)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp (+11, -10)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp (+49, -20)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp (+28, -6)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp (+39, -13)
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp (+5, -3)
include/ck_tile/ops/fused_moe.hpp (+3, -2)
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp (+8, -5)
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp (+634, -59)
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp (+52, -0)
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp (+98, -36)
include/ck_tile/ops/gemm.hpp (+4, -1)
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp (+84, -56)
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp (+5, -2)
Too many changes to show: only the 20 files listed above (of the 465 changed files) are rendered in the diff below.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

@@ -20,10 +20,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
+template <typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdKernel
 {
-    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;

     static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;

@@ -84,7 +83,7 @@ struct FmhaFwdKernel
             return n.empty() ? n : std::string("p") + n;
         }();
         return _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
-               "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
+               "_" + (kIsGroupMode ? "group" : "batch") + "_"
                "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
                _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
                "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" +
                _TS_(g0br::at(ck_tile::number<1>{})) + "x" +
                _TS_(g0br::at(ck_tile::number<2>{})) + "_" +

@@ -867,9 +866,75 @@ struct FmhaFwdKernel
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_q_,
-                                                ck_tile::index_t hdim_v_)
+                                                ck_tile::index_t hdim_v_,
+                                                bool has_padded_seqlen_k = false)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
+        // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
+        if(has_padded_seqlen_k)
+        {
+            // TODO: this may need tuning
+            return dim3(nhead_,
+                        batch_size_,
+                        ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
+                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
+        }
+        else
+        {
+            // TODO: this may need tuning
+            return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
+                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
+                        nhead_,
+                        batch_size_);
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
+    {
+        bool has_padded_seqlen_k = false;
+        if constexpr(kIsGroupMode)
+            has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
+
+        if(has_padded_seqlen_k)
+        {
+            // const index_t num_tile_m0 = seqlen_q / kM0;
+            const index_t num_tile_n1 =
+                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+            const index_t i_block = blockIdx.z;
+            const index_t i_nhead = blockIdx.x;
+            const index_t i_batch = blockIdx.y;
+
+            const auto f = [](index_t dividend, index_t divisor) {
+                index_t quotient = dividend / divisor;
+                index_t modulus  = dividend - quotient * divisor;
+                return ck_tile::make_tuple(quotient, modulus);
+            };
+
+            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+            return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+        }
+        else
+        {
+            // const index_t num_tile_m0 = seqlen_q / kM0;
+            const index_t num_tile_n1 =
+                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+            const index_t i_block = blockIdx.x;
+            const index_t i_nhead = blockIdx.y;
+            const index_t i_batch = blockIdx.z;
+
+            const auto f = [](index_t dividend, index_t divisor) {
+                index_t quotient = dividend / divisor;
+                index_t modulus  = dividend - quotient * divisor;
+                return ck_tile::make_tuple(quotient, modulus);
+            };
+
+            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+            return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+        }
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

@@ -885,8 +950,7 @@ struct FmhaFwdKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
-            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
+        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
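Aside (not part of the commit): a minimal host-side sketch of the quotient/remainder mapping that the new GetTileIndex performs on the flattened grid coordinate. The tile and problem sizes below are made-up illustration values standing in for FmhaPipeline::kM0 / kN1; when has_padded_seqlen_k is set, the same pair is simply read from blockIdx.z instead of blockIdx.x.

#include <cstdio>

// Hypothetical tile sizes for illustration only (real values come from FmhaPipeline).
constexpr int kM0 = 128; // seqlen_q tile
constexpr int kN1 = 128; // hdim_v tile

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int seqlen_q = 512, hdim_v = 256;                    // made-up problem sizes
    const int num_tile_n1 = ceil_div(hdim_v, kN1);             // 2 tiles along hdim_v
    const int grid_x = ceil_div(seqlen_q, kM0) * num_tile_n1;  // 4 * 2 = 8 blocks in grid.x

    // Same decomposition as the lambda `f` in GetTileIndex: quotient selects the
    // seqlen_q tile, remainder selects the hdim_v tile.
    for(int i_block = 0; i_block < grid_x; ++i_block)
    {
        const int i_tile_m = i_block / num_tile_n1;
        const int i_tile_n = i_block - i_tile_m * num_tile_n1;
        std::printf("blockIdx.x=%d -> i_tile_m=%d, i_tile_n=%d\n", i_block, i_tile_m, i_tile_n);
    }
    return 0;
}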
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp

@@ -5,10 +5,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
+template <typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdSplitKVCombineKernel
 {
-    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

@@ -235,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
         return kargs;
     }

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                             ck_tile::index_t nhead,
                                             ck_tile::index_t max_seqlen_q,
                                             ck_tile::index_t hdim_v)
     {
-        return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
+        // TODO: this may need tuning
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
+                    nhead,
+                    batch_size);
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
+    {
+        const index_t num_tile_n1 =
+            ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+            return ck_tile::make_tuple(quotient, modulus);
+        };
+
+        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

@@ -256,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
-            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
+        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
deleted file (mode 100644 → 0); the entire content below is removed:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, kN1),
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp

@@ -17,10 +17,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
+template <typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdSplitKVKernel
 {
-    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;

     static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;

@@ -48,10 +47,16 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV        = FmhaPipeline::Problem::kIsPagedKV;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ =
+        FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
     using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;

+    static_assert(!kMergeNumHeadGroupsSeqLenQ ||
+                  (kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
+                   !kHasMask));
+
     // clang-format off
     template <typename T> struct t2s;
     template <> struct t2s<float> { static constexpr const char * name = "fp32"; };

@@ -476,13 +481,40 @@ struct FmhaFwdSplitKVKernel
         return kargs;
     }

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
-                                            ck_tile::index_t nhead,
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                                ck_tile::index_t nhead_q,
+                                                ck_tile::index_t nhead_kv,
                                                 ck_tile::index_t max_seqlen_q,
                                                 ck_tile::index_t hdim_v,
                                                 ck_tile::index_t num_splits)
     {
-        return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
+        ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
+        ck_tile::index_t max_seqlen_q_ =
+            max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
+
+        // TODO: this may need tuning
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
+                    nhead_,
+                    batch_size);
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
+    {
+        const index_t num_tile_n1 =
+            ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+            return ck_tile::make_tuple(quotient, modulus);
+        };
+
+        const auto [mn, i_split]        = f(blockIdx.x, kargs.num_splits);
+        const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
+
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+
+        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

@@ -498,8 +530,7 @@ struct FmhaFwdSplitKVKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] =
-            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits);
+        const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

@@ -542,7 +573,7 @@ struct FmhaFwdSplitKVKernel
         // # of required blocks is different in each groups, terminate unnecessary blocks
         // earlier
-        if(kargs.seqlen_q <= i_m0)
+        if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
         {
             return;
         }

@@ -597,30 +628,60 @@ struct FmhaFwdSplitKVKernel
         }

         // for simplicity, batch stride we just modify the pointer
+        const index_t i_nhead_k =
+            (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
+
         const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
-                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
+                                 static_cast<long_index_t>(i_nhead) *
+                                     (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                                     kargs.nhead_stride_q +
                                  batch_offset_q;
-        const KDataType* k_ptr =
-            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
-            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
-            batch_offset_k;
-        const VDataType* v_ptr =
-            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
-            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
-            batch_offset_v;
+        const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
+                                 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
+                                 batch_offset_k;
+        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
+                                 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
+                                 batch_offset_v;
         ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
-                               static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
+                               static_cast<long_index_t>(i_nhead) *
+                                   (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                                   kargs.nhead_stride_o_acc +
                                batch_offset_o_acc + i_split * kargs.split_stride_o_acc;

         // Q/K/V DRAM and DRAM window
-        const auto q_dram = [&]() {
-            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+        const auto q_dram = [&] {
+            const auto q_dram_naive = [&] {
+                if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                {
+                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
+                    // hdim_q)
+                    const auto view = make_naive_tensor_view<address_space_enum::global>(
+                        q_ptr,
+                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
+                        make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
+                        number<FmhaPipeline::kAlignmentQ>{},
+                        number<1>{});
+
+                    return transform_tensor_view(
+                        view,
+                        make_tuple(
+                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
+                            make_pass_through_transform(kargs.hdim_q)),
+                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                        make_tuple(sequence<0>{}, sequence<1>{}));
+                }
+                else
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        q_ptr,
+                        make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                        make_tuple(kargs.stride_q, 1),
+                        number<FmhaPipeline::kAlignmentQ>{},
+                        number<1>{});
+                }
+            }();
+
             if constexpr(FmhaPipeline::kQLoadOnce)
             {
                 return pad_tensor_view(

@@ -709,7 +770,7 @@ struct FmhaFwdSplitKVKernel
             }
         }();

-        auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+        auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
             if constexpr(kIsPagedKV)
             {
                 const auto* block_indices =

@@ -719,8 +780,7 @@ struct FmhaFwdSplitKVKernel
                     integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);

                 const long_index_t fixed_offset =
-                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
-                    kargs.nhead_stride_k;
+                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;

                 return make_page_block_navigator<const KDataType, 0>(
                     kargs.k_ptr,

@@ -740,7 +800,7 @@ struct FmhaFwdSplitKVKernel
             }
         }();

-        auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+        auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
             if constexpr(kIsPagedKV)
             {
                 const auto* block_indices =

@@ -750,8 +810,7 @@ struct FmhaFwdSplitKVKernel
                     integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);

                 const long_index_t fixed_offset =
-                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
-                    kargs.nhead_stride_v;
+                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;

                 return make_page_block_navigator<const VDataType, 1>(
                     kargs.v_ptr,

@@ -822,19 +881,40 @@ struct FmhaFwdSplitKVKernel
         // lse acc
         auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
             constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});

             LSEDataType* lse_acc_ptr =
                 reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
-                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc +
+                static_cast<long_index_t>(i_nhead_) *
+                    (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                    kargs.nhead_stride_lse_acc +
                 batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;

-            const auto lse_acc_dram = [&]() {
-                const auto lse_acc_dram_naive =
-                    make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr,
+            const auto lse_acc_dram = [&] {
+                const auto lse_acc_dram_naive = [&] {
+                    if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                    {
+                        // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
+                        const auto view = make_naive_tensor_view<address_space_enum::global>(
+                            lse_acc_ptr,
+                            make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
+                            make_tuple(kargs.nhead_stride_lse_acc, 1),
+                            number<1>{},
+                            number<1>{});
+
+                        return transform_tensor_view(
+                            view,
+                            make_tuple(make_merge_transform(
+                                make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q))),
+                            make_tuple(sequence<0, 1>{}),
+                            make_tuple(sequence<0>{}));
+                    }
+                    else
+                    {
+                        return make_naive_tensor_view<address_space_enum::global>(
+                            lse_acc_ptr,
+                            make_tuple(kargs.seqlen_q),
+                            make_tuple(1),
+                            number<1>{},
+                            number<1>{});
+                    }
+                }();
+
                 return pad_tensor_view(
                     lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
             }();

@@ -933,13 +1013,37 @@ struct FmhaFwdSplitKVKernel
         }();

         // Oacc DRAM and Oacc DRAM window
-        auto o_acc_dram = [&]() {
-            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+        auto o_acc_dram = [&] {
+            const auto o_acc_dram_naive = [&] {
+                if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                {
+                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
+                    // hdim_v)
+                    const auto view = make_naive_tensor_view<address_space_enum::global>(
+                        o_acc_ptr,
+                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
+                        make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
+                        number<FmhaPipeline::kAlignmentOacc>{},
+                        number<1>{});
+
+                    return transform_tensor_view(
+                        view,
+                        make_tuple(
+                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
+                            make_pass_through_transform(kargs.hdim_v)),
+                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                        make_tuple(sequence<0>{}, sequence<1>{}));
+                }
+                else
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        o_acc_ptr,
+                        make_tuple(kargs.seqlen_q, kargs.hdim_v),
+                        make_tuple(kargs.stride_o_acc, 1),
+                        number<FmhaPipeline::kAlignmentOacc>{},
+                        number<1>{});
+                }
+            }();
+
             return pad_tensor_view(o_acc_dram_naive,
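Aside (not part of the commit): the effect of kMergeNumHeadGroupsSeqLenQ on the split-KV launch grid can be checked with a small host-side sketch. All sizes below are hypothetical, and kM0 / kN1 stand in for FmhaPipeline::kM0 / kN1.

#include <cstdio>

// Hypothetical tile sizes for illustration only.
constexpr int kM0 = 64, kN1 = 128;

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int batch = 2, nhead_q = 32, nhead_kv = 8;
    const int max_seqlen_q = 1024, hdim_v = 128, num_splits = 4;

    // kMergeNumHeadGroupsSeqLenQ == false: one grid.y slot per query head.
    const int gx_plain = ceil_div(max_seqlen_q, kM0) * ceil_div(hdim_v, kN1) * num_splits;
    std::printf("plain  grid = (%d, %d, %d)\n", gx_plain, nhead_q, batch);      // (64, 32, 2)

    // kMergeNumHeadGroupsSeqLenQ == true: the nhead_q / nhead_kv = 4 query heads that share
    // a KV head are folded into the seqlen dimension instead of grid.y.
    const int max_seqlen_q_merged = max_seqlen_q * (nhead_q / nhead_kv);
    const int gx_merged = ceil_div(max_seqlen_q_merged, kM0) * ceil_div(hdim_v, kN1) * num_splits;
    std::printf("merged grid = (%d, %d, %d)\n", gx_merged, nhead_kv, batch);    // (256, 8, 2)

    // Both variants launch the same total number of workgroups (2048 * batch here);
    // only the mapping of query heads to blocks changes.
    return 0;
}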
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
deleted file (mode 100644 → 0); the entire content below is removed:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v,
                                                ck_tile::index_t num_splits)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/,
                                   ck_tile::index_t hdim_v,
                                   ck_tile::index_t num_splits)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [mn, i_split]        = f(blockIdx.x, num_splits);
        const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);

        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
deleted file (mode 100644 → 0); the entire content below is removed:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "shb";

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1),
                    nhead_,
                    batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;

template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "hbs";

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(nhead_,
                    batch_size_,
                    ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1));
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.z;
        const index_t i_nhead = blockIdx.x;
        const index_t i_batch = blockIdx.y;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

@@ -343,6 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
                 // moving k_dram_window is an in-page-block operation, so there is
                 // no need to invoke k_page_block_navigator.move_tile_window() here.
                 move_tile_window(k_dram_window, {0, kK0});
+                // ensure LDS access by Q is done before the over-writting by K
+                block_sync_lds();
                 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
                 do
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

@@ -103,6 +103,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV        = Traits::kIsPagedKV;
     static constexpr bool kHasUnevenSplits  = kIsGroupMode || Traits::kHasUnevenSplits;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;

     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp

@@ -5,14 +5,14 @@

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

 namespace ck_tile {

-/// NOTICE: we no-longer use this pipeline.
 // This pipeline is qkv all located in LDS
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
-struct [[deprecated]] BlockFmhaPipelineQSKSVS
+struct BlockFmhaPipelineQSKSVS
 {
     using Problem = remove_cvref_t<Problem_>;
     using Policy  = remove_cvref_t<Policy_>;

@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;
+    static constexpr bool kHasDropout  = Problem::kHasDropout;
+
+    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
+    // ... together with tensor distribution. tensor dist should able to overwrite this
+    static constexpr index_t kAlignmentQ =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
+    static constexpr index_t kAlignmentK =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
+    static constexpr index_t kAlignmentV = []() {
+        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
+        else
+            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
+    }();
+
+    static constexpr index_t kAlignmentO =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
+    static constexpr index_t kAlignmentBias =
+        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
+
     static constexpr index_t kBlockPerCu = []() {
         if constexpr(Problem::kBlockPerCu != -1)

@@ -81,20 +99,18 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr const char* name = "qs";

     using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
     }

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
     {
         return Policy::template GetSmemSizeQ<Problem>();
     }

     template <typename QDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
               typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename QElementFunction,
               typename KElementFunction,

@@ -114,6 +130,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                const VElementFunction& v_element_func,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                const BiasElementFunction& bias_element_func,
+               RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
                LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
                const LSEElementFunction& lse_element_func,
                const SAccElementFunction& s_acc_element_func,

@@ -122,7 +139,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
-               void* smem_ptr) const
+               void* smem_ptr,
+               DropoutType& /* unused_dropout */) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&

@@ -222,11 +240,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                                       {seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window =
             make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                              bias_dram_block_window_tmp.get_window_lengths(),
                              {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-                             Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),

@@ -305,7 +323,6 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                 });
             }

-            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
             { // tail
                 block_sync_lds();
                 gemm_0(s_acc, q_lds_window, k_lds_window);

@@ -318,6 +335,10 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                 gemm_0(s_acc, q_lds_window, k_lds_window);
             }

+            __builtin_amdgcn_sched_barrier(0);
+            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
+            __builtin_amdgcn_sched_barrier(0);
+
             // STAGE 2, scale_s, add bias, mask, softmax
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {

@@ -439,6 +460,12 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})

             block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});

+            const auto p =
+                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+
+            __builtin_amdgcn_sched_barrier(0);
+
             // l{j}, Oacc{j}
             constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

             sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {

@@ -486,9 +513,6 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             }

             move_tile_window(v_dram_window, {0, kK1});

-            const auto p =
-                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
-
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {

@@ -583,6 +607,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto

@@ -590,11 +615,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
                const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
+               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
-               void* smem_ptr) const
+               void* smem_ptr,
+               DropoutType& dropout) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},

@@ -604,6 +631,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           identity{},
                           bias_dram_block_window_tmp,
                           identity{},
+                          randval_dram_block_window_tmp,
                           lse_dram_block_window_tmp,
                           identity{},
                           identity{},

@@ -612,7 +640,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           mask,
                           position_encoding,
                           scale_s,
-                          smem_ptr);
+                          smem_ptr,
+                          dropout);
     }
 };
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp

@@ -9,11 +9,33 @@
 namespace ck_tile {

 // This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQSKSVSDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
-                                        /* NumPrefetchV = */ 1>;
+struct BlockFmhaPipelineQSKSVSDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
+                                          /* AsyncCopyK = */ false,
+                                          /* AsyncCopyV = */ false,
+                                          /* NumPrefetchK = */ 1,
+                                          /* NumPrefetchV = */ 1>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
+    {
+        return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::KDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
+    {
+        return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::VDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
+               GetSmemSizeDropout<Problem>();
+    }
+};

 } // namespace ck_tile
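Aside (not part of the commit): a small host-side sketch of how the GetSmemSize composition above behaves. The Q and K tiles live in LDS at the same time, while the V tile reuses that space later, so the request is max(Q+K, V) plus the dropout scratch. The byte counts below are made-up placeholders, not values computed by the real MakeKLdsBlockDescriptor / MakeVLdsBlockDescriptor.

#include <algorithm>
#include <cstdio>

int main()
{
    // Hypothetical per-tile LDS footprints in bytes (illustration only).
    const std::size_t smem_q = 16 * 1024;
    const std::size_t smem_k = 32 * 1024;
    const std::size_t smem_v = 32 * 1024;
    const std::size_t smem_dropout = 0; // no dropout scratch in this example

    // Mirrors GetSmemSize(): Q and K coexist, V reuses the same LDS afterwards.
    const std::size_t total = std::max(smem_q + smem_k, smem_v) + smem_dropout;
    std::printf("LDS request: %zu bytes\n", total); // 49152, not 80 KiB
    return 0;
}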
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp

@@ -125,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     }
 };

-/// NOTICE: we no-longer use this policy.
 template <>
-struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
+struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
 {
     static constexpr bool QLoadOnce = false;

@@ -147,8 +146,16 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
-        using QDataType = remove_cvref_t<typename Problem::QDataType>;
-        return 16 / sizeof(QDataType);
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
+        // this should align with MakeQDramTileDistribution()
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+
+        return min(ElemPerThread, MaxVectorSize);
     }

     template <typename Problem>

@@ -157,19 +164,25 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
         using QDataType = remove_cvref_t<typename Problem::QDataType>;

         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

-        constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
-        constexpr index_t K0 = kKPerBlock / K1;
-        constexpr index_t M2 = get_warp_size() / K0;
-        constexpr index_t M1 = kBlockSize / get_warp_size();
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
+        constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
+
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
+
+        constexpr index_t KPerThread     = kMaxVecLoad;
+        constexpr index_t KThreads       = kKPerBlock / KPerThread;
+        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
+        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
+        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);

         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
                                        tuple<sequence<1>, sequence<1, 2>>,
                                        tuple<sequence<1>, sequence<2, 0>>,
                                        sequence<1, 2>,

@@ -216,18 +229,31 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
                                            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                                            typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

        constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
        static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);

        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                         std::is_same_v<typename Problem::KDataType, half_t> &&
                         std::is_same_v<typename Problem::SaccDataType, float>)
            {
                if constexpr(WarpGemmM == 32)
                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
                else if constexpr(WarpGemmM == 16)
                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
                else // WarpGemmM == 4
                    return WarpGemmMfmaF16F16F32M4N64K16{};
            }
            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
                              std::is_same_v<typename Problem::SaccDataType, float>)
            {
                if constexpr(WarpGemmM == 32)
                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
                else if constexpr(WarpGemmM == 16)
                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
                else // WarpGemmM == 4
                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
            }
            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
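Aside (not part of the commit): a worked example of the vector-length rule that the new GetAlignmentQ introduces above. The block and tile sizes are hypothetical, chosen only to show the arithmetic.

#include <algorithm>
#include <cstdio>

// Mirrors the updated GetAlignmentQ() logic with made-up inputs.
static int alignment_q(int kMPerBlock, int kKPerBlock, int kBlockSize, int bytes_per_elem)
{
    const int max_vector_size = 16 / bytes_per_elem;                  // 16-byte (dwordx4) load
    const int elem_per_thread = (kMPerBlock * kKPerBlock) / kBlockSize;
    return std::min(elem_per_thread, max_vector_size);
}

int main()
{
    // fp16 (2 bytes): a 128x32 Q tile on 256 threads gives 16 elements per thread,
    // but the 16-byte load limit caps the vector length at 8.
    std::printf("%d\n", alignment_q(128, 32, 256, 2)); // 8
    // A smaller 64x16 tile on 256 threads has only 4 elements per thread, so the
    // vector length drops to 4, whereas the old code returned 16/sizeof(QDataType) = 8.
    std::printf("%d\n", alignment_q(64, 16, 256, 2));  // 4
    return 0;
}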
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp

@@ -43,6 +43,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kDoFp8StaticQuant_,
           bool kIsPagedKV_,
           bool kHasUnevenSplits_,
+          bool kMergeNumHeadGroupsSeqLenQ_ = false,
           index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
 struct TileFmhaFwdSplitKVTraits
 {

@@ -57,6 +58,7 @@ struct TileFmhaFwdSplitKVTraits
     static constexpr bool kIsPagedKV = kIsPagedKV_;
     // determine if some split (length) is not divisible by tile size
     static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
     static constexpr index_t kBlockPerCu   = kBlockPerCu_;
 };
include/ck_tile/ops/fused_moe.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -7,6 +7,7 @@
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
 #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
+#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"

@@ -14,6 +15,6 @@
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
-#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp

@@ -22,7 +22,7 @@
 // (only for reference)       exp-0  exp-1      exp-2  exp-3            exp-4  exp-5
 // weight_id_per_expert is:  [[a],   [g, j, m], [d, k], [b, e, h, l, n], [],   [c, f, i, o]]
 //
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
 // * this could be larger than actual, since actual tokens are on GPU
 //
 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]

@@ -111,7 +111,7 @@ struct FusedMoeGemmHostArgs
     const void* num_sorted_tiles_ptr; // [1]

     index_t hidden_size;       // k
-    index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
+    index_t intermediate_size; // n / TP, for Gate/UP/Down
     index_t num_tokens;        // input number of tokens for current iteration
     index_t num_experts;       // number of groups
     index_t topk;              // need this?

@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
             return base_str;
         }();

         return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_" : "g1u1_") +
                _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
                _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
                _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);

@@ -204,7 +204,7 @@ struct FusedMoeGemmKernel
         const void* num_sorted_tiles_ptr;

         index_t hidden_size;       // k
-        index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
+        index_t intermediate_size; // n / TP, for Gate/Up/Down
         index_t num_tokens;        // input number of tokens for current iteration
         index_t num_experts;       // number of groups
         index_t topk;              // need this?

@@ -239,7 +239,7 @@ struct FusedMoeGemmKernel
     {
         if constexpr(UseUK)
         {
-            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
+            __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
             IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                 *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

@@ -298,6 +298,9 @@ struct FusedMoeGemmKernel
             index_t token_id =
                 reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+            token_id &= 0xffffff;
+#endif
             auto topk_weight =
                 reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
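Aside (not part of the commit): to make the updated max_num_tokens_padded bound above concrete, here is a small arithmetic sketch. The token count, topk, expert count and unit size M_a below are made-up illustration values, and M_a is read as the per-expert padding unit used by the sorting kernel.

#include <cstdio>

int main()
{
    // Hypothetical routing problem: 5 input tokens, each routed to topk = 3 experts,
    // 6 experts in total, per-expert padding unit M_a = 4 rows.
    const int input_tokens = 5, topk = 3, num_experts = 6, M_a = 4;

    const int old_bound = topk * input_tokens + num_experts * (M_a - 1);  // 15 + 18 = 33
    const int new_bound = topk * input_tokens + num_experts * M_a - topk; // 15 + 24 - 3 = 36

    std::printf("old bound = %d, updated bound = %d\n", old_bound, new_bound);
    // Both are upper bounds on the padded sorted_token_ids length; the actual padded
    // length is only known on the GPU once the per-expert token counts are computed.
    return 0;
}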
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
7572a691
...
...
@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
...
...
@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
// and modify the output expert-ID, because we will only have enbaled expert on specific GPU.
// we call expert input to this kernel as "global expert id", output as "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
...
...
@@ -67,10 +99,80 @@ namespace ck_tile {
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST
constexpr
auto
moe_sorting_get_smem_row_col
(
int
num_tokens_
,
int
num_experts_
)
{
/* num_experts + 1
* +--------------------------------------+
* | |
* | |
* | | * -> sub-tokens
* | |
* | |
* +--------------------------------------+
* | | 2 -> cumsum buffer
* +--------------------------------------+
*
*/
int
smem_cols
=
num_experts_
+
1
;
// usually experts is power of 2. padding here
int
smem_rows
=
[
&
](){
index_t
target_occupancy_
=
2
;
constexpr
index_t
total_
=
65536
/
sizeof
(
int
);
constexpr
index_t
sub_unroll
=
8
;
constexpr
index_t
cumsum_bufs
=
2
;
// 1 for cumsum, 1 for cnt
// at lease 2 lines, one for sub_token unroll, one for cumsum
// should be enough
if
((
total_
/
target_occupancy_
)
<
((
cumsum_bufs
+
sub_unroll
)
*
smem_cols
))
{
if
((
total_
/
1
)
<
((
cumsum_bufs
+
sub_unroll
)
*
smem_cols
))
throw
std
::
runtime_error
(
"too many num_experts, can't allocate smem"
);
target_occupancy_
=
1
;
}
int
r
=
total_
/
target_occupancy_
/
smem_cols
;
// round to sub_unroll multipl
int
r_for_sub_token
=
r
-
cumsum_bufs
;
r_for_sub_token
=
min
(
r_for_sub_token
,
num_tokens_
);
r_for_sub_token
=
(
r_for_sub_token
+
sub_unroll
-
1
)
/
sub_unroll
*
sub_unroll
;
r_for_sub_token
=
max
(
r_for_sub_token
,
1
);
if
(
r_for_sub_token
>
1
)
{
int
r_unroll_
=
r_for_sub_token
/
sub_unroll
;
// round to 1x/2x/4x/8x number of sub_unroll
int
clz_
=
__builtin_clz
(
r_unroll_
);
// 0b1:31 0b2:30, 0b3:30, 0b4:29
int
mask_
=
(
1
<<
(
31
-
clz_
))
-
1
;
mask_
=
mask_
>
0b111
?
0b111
:
mask_
;
//clamp to 8x at most
mask_
=
~
mask_
;
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
r_for_sub_token
=
(
r_unroll_
&
mask_
)
*
sub_unroll
;
}
// final check
if
(
(
r_for_sub_token
+
cumsum_bufs
*
smem_cols
*
target_occupancy_
)
>=
total_
)
{
throw
std
::
runtime_error
(
"can't run this kernel, request LDS over size"
);
}
return
r_for_sub_token
+
cumsum_bufs
;
}();
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return
ck_tile
::
make_tuple
(
smem_rows
,
smem_cols
);
}
struct
MoeSortingHostArgs
{
const
void
*
p_topk_ids
;
// [token, topk]
const
void
*
p_weights
;
// [token, topk]
const
void
*
p_local_expert_mask
;
void
*
p_sorted_token_ids
;
void
*
p_sorted_weights
;
void
*
p_sorted_expert_ids
;
...
...
@@ -101,6 +203,7 @@ struct MoeSortingKernel
{
const
void
*
p_topk_ids
;
const
void
*
p_weights
;
const
void
*
p_local_expert_mask
;
void
*
p_sorted_token_ids
;
void
*
p_sorted_weights
;
void
*
p_sorted_expert_ids
;
...
...
@@ -111,8 +214,11 @@ struct MoeSortingKernel
index_t
moe_buf_bytes
;
index_t
tokens_per_thread
;
index_t
smem_rows
;
mdiv
unit_size_mdiv
;
mdiv
topk_mdiv
;
mdiv
expert_mdiv
;
// mdiv sub_tokens_mdiv;
};
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
...
...
@@ -123,15 +229,25 @@ struct MoeSortingKernel
CK_TILE_HOST
static
constexpr
auto
BlockSize
(
const
Hargs
&
h
)
{
#if MOE_SORTING_USE_EX_KERNEL
(
void
)
h
;
return
dim3
(
256
);
#else
return
dim3
(
ck_tile
::
integer_least_multiple
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
#endif
}
// in byte
CK_TILE_HOST
static
constexpr
auto
GetSmemSize
(
const
Hargs
&
h
)
{
#if MOE_SORTING_USE_EX_KERNEL
auto
[
smem_rows
,
smem_cols
]
=
moe_sorting_get_smem_row_col
(
h
.
tokens
,
h
.
num_experts
);
return
smem_rows
*
smem_cols
*
sizeof
(
int
);
#else
const
auto
blocks
=
BlockSize
(
h
);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return
((
blocks
.
x
+
1
)
*
(
h
.
num_experts
+
1
)
+
(
h
.
num_experts
+
1
))
*
sizeof
(
index_t
);
#endif
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
...
...
@@ -139,6 +255,7 @@ struct MoeSortingKernel
Kargs
k
;
k
.
p_topk_ids
=
h
.
p_topk_ids
;
k
.
p_weights
=
h
.
p_weights
;
k
.
p_local_expert_mask
=
h
.
p_local_expert_mask
;
k
.
p_sorted_token_ids
=
h
.
p_sorted_token_ids
;
k
.
p_sorted_weights
=
h
.
p_sorted_weights
;
k
.
p_sorted_expert_ids
=
h
.
p_sorted_expert_ids
;
...
...
@@ -152,10 +269,18 @@ struct MoeSortingKernel
k
.
tokens_per_thread
=
integer_divide_ceil
(
h
.
tokens
*
h
.
topk
,
blocks
.
x
);
k
.
unit_size_mdiv
=
mdiv
{
static_cast
<
uint32_t
>
(
h
.
unit_size
)};
k
.
topk_mdiv
=
mdiv
{
static_cast
<
uint32_t
>
(
h
.
topk
)};
k
.
smem_rows
=
[
&
](){
auto
[
r_
,
c_
]
=
moe_sorting_get_smem_row_col
(
h
.
tokens
,
h
.
num_experts
);
(
void
)
c_
;
return
r_
;
}();
k
.
expert_mdiv
=
mdiv
{
static_cast
<
uint32_t
>
(
h
.
num_experts
)};
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
return
k
;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size need at least be 16!! dpp 16 is one row
template
<
typename
data_t
,
int
wave_size
>
__device__
inline
void
wave_cumsum
(
data_t
&
thread_data
)
const
{
...
...
@@ -196,6 +321,40 @@ struct MoeSortingKernel
bank_mask
,
bound_ctrl
)));
// row_shr:4
}
if
constexpr
(
wave_size
==
8
)
{
// wave-size=8 need one extra shift
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x118
,
row_mask
,
bank_mask
,
bound_ctrl
)));
// row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
data_t
xxx
=
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x157
,
row_mask
,
bank_mask
,
bound_ctrl
));
// row_newbcast:7
data_t
yyy
=
(
__lane_id
()
/
8
)
%
2
==
0
?
0
:
xxx
;
thread_data
=
thread_data
-
yyy
;
#endif
}
if
constexpr
(
wave_size
>
8
)
{
thread_data
=
...
...
@@ -224,6 +383,36 @@ struct MoeSortingKernel
}
}
// reduce single pixel within a wave
template
<
typename
T
,
typename
F
,
index_t
wave_size_
=
warpSize
>
__device__
static
constexpr
T
wave_reduce
(
T
local
,
F
reduce_f
,
number
<
wave_size_
>
=
{})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr
int
reduce_stage
=
[](){
if
constexpr
(
wave_size_
==
2
)
return
1
;
else
if
constexpr
(
wave_size_
==
4
)
return
2
;
else
if
constexpr
(
wave_size_
==
8
)
return
3
;
else
if
constexpr
(
wave_size_
==
16
)
return
4
;
else
if
constexpr
(
wave_size_
==
32
)
return
5
;
else
if
constexpr
(
wave_size_
==
64
)
return
6
;
else
return
0
;
}();
// clang-format on
T
v_local
=
local
;
#pragma unroll reduce_stage
for
(
int
i_stage
=
0
;
i_stage
<
reduce_stage
;
i_stage
++
)
{
int
src_lane
=
__lane_id
()
^
(
1
<<
i_stage
);
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
v_local
));
T
v_remote
=
bit_cast
<
T
>
(
v_remote_tmp
);
v_local
=
reduce_f
(
v_local
,
v_remote
);
}
return
v_local
;
}
CK_TILE_DEVICE
index_t
calc_index
(
index_t
total_col
,
index_t
row
,
index_t
col
)
const
{
return
row
*
total_col
+
col
;
...
...
@@ -257,37 +446,37 @@ struct MoeSortingKernel
        index_t* shared_mem = reinterpret_cast<index_t*>(smem);

        index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)

        for(int i = 0; i < num_experts; ++i)
        {
            tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
        }

        __syncthreads();
#if 1
        if(tid < num_experts)
        {
            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
            index_t local_c[8];
            index_t prev_c = 0;
            // TODO: manually unroll. pragma unroll does not work well when we have dependency
            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
            {
                local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
                local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
                local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
                local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
                local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
                local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
                local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
                local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];

                local_c[0] += prev_c;
                local_c[1] += local_c[0];
...
...
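tokens_cnts is a row-major table in shared memory addressed through calc_index. A hedged sketch of that indexing (illustrative names; the extra column is presumably padding):

    // Hedged sketch of the shared-memory table behind tokens_cnts. It has num_experts + 1
    // columns; row 0 is the zero-initialized base of the column-wise prefix sums and
    // row tid + 1 holds the counts produced by thread tid.
    static CK_TILE_DEVICE index_t example_token_cnt_index(index_t num_experts, index_t tid, index_t expert)
    {
        const index_t total_col = num_experts + 1; // column stride of the table
        const index_t row       = tid + 1;         // row 0 reserved for the prefix-sum base
        return row * total_col + expert;           // same arithmetic as calc_index()
    }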
@@ -299,50 +488,56 @@ struct MoeSortingKernel
                local_c[7] += local_c[6];
                prev_c = local_c[7];
                tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
                tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
                tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
                tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
                tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
                tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
                tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
                tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
            }
        }
#else
        // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
        // heuristic
        {
            if(tid < num_experts)
                tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
            for(int i = 0; i < num_experts; i += 8)
            {
                index_t local_c[8];
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
                }
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    wave_cumsum<int, 64>(local_c[j]);
                }
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
                }
            }
        }
#endif
        __syncthreads();

        if constexpr(Problem::ExpertTile == 0)
        {
            if(tid == 0)
            {
                cumsum[0] = 0;
                for(int i = 1; i <= num_experts; ++i)
                {
                    auto current_units = [&]() {
                        index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
                                     unit_size_mdiv.divisor - 1;
                        index_t y_ = unit_size_mdiv.div(x_);
                        return max(y_, 1) * unit_size_mdiv.divisor;
...
...
@@ -351,20 +546,24 @@ struct MoeSortingKernel
                }
                *p_total_tokens_post_pad = cumsum[num_experts];
            }
        }
        else
        {
            // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
            // expert) for simplicity, not check experts here.
            int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
            int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
            int local_cumsum = padded_tokens_per_expert;
            wave_cumsum<int, 64>(local_cumsum);
            if(tid == (num_experts - 1))
            {
                cumsum[0] = 0;
                *p_total_tokens_post_pad = local_cumsum;
            }
            if(tid < num_experts)
            {
                cumsum[tid + 1] = local_cumsum;
            }
        }
...
...
@@ -384,7 +583,7 @@ struct MoeSortingKernel
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            index_t expert_id = topk_id[i];
            index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
            index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            uint32_t curr_token_id, curr_topk_id;
...
...
@@ -394,15 +593,16 @@ struct MoeSortingKernel
            p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
            p_sorted_weights[rank_post_pad] = weights[i];
            tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
        }

        if constexpr(Problem::ExpertTile == 0)
        {
            const index_t prefill_token = topk_mdiv.div(numel);
            if(tid < num_experts)
            {
                index_t expert_offset =
                    cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
                index_t expert_end = cumsum[tid + 1];
                while(expert_offset < expert_end)
                {
...
...
@@ -417,16 +617,19 @@ struct MoeSortingKernel
                }
            }
        }
        else
        {
            const index_t prefill_token = topk_mdiv.div(numel);
            // TODO: only support expert-tile like 8, 16, 32
            static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
            {
                index_t eid = tid / experts_per_wave;
                index_t expert_offset = cumsum[eid] +
                                        tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
                                        tid % experts_per_wave;
                index_t expert_end = cumsum[eid + 1];
                if(eid < num_experts)
                {
                    while(expert_offset < expert_end)
                    {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
...
...
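Both the tid == 0 path and the wave_cumsum path above pad each expert's token count up to a multiple of the unit (GEMM tile) size before accumulating offsets. A hedged sketch of that arithmetic (illustrative names only):

    // Hedged sketch (illustrative): the per-expert padding used when building the offsets.
    // Each count is rounded up to a multiple of unit_size, and these paths keep at least
    // one unit even for an expert that received no tokens.
    static constexpr int reference_padded_tokens(int token_cnt, int unit_size)
    {
        int blocks = (token_cnt + unit_size - 1) / unit_size; // ceil division
        if(blocks < 1)
            blocks = 1;
        return blocks * unit_size;
    }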
@@ -436,9 +639,362 @@ struct MoeSortingKernel
                        p_sorted_token_ids[expert_offset] = prefill_token;
#endif
                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
                        expert_offset += experts_per_wave;
                    }
                }
            }
        }
    }

    // only support index_t, and single pixel access
    struct simple_smem_indexer
    {
        index_t* smem;
        index_t row_stride;

        // this is 2D
        CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
            : smem(smem_), row_stride(row_stride_)
        {
        }
        CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
        {
            return smem[i_row * row_stride + i_col];
        }
        CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
        {
            return smem[i_row * row_stride + i_col];
        }

        // this is 1D or linear
        CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
        CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
        CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
    };
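simple_smem_indexer is just a thin row-major view over a chunk of shared memory. A hedged usage sketch of how the _ex kernel below carves its buffer (buffer names are illustrative; the kernel uses smem_cols = num_experts + 1):

    // Hedged usage sketch (illustrative only).
    CK_TILE_DEVICE void example_carve_smem(index_t* smem, index_t num_experts) const
    {
        const index_t smem_cols = num_experts + 1;
        simple_smem_indexer cumsum{smem};                            // 1D view, num_experts + 1 entries
        simple_smem_indexer cumdup{smem + smem_cols};                // 1D view, num_experts + 1 entries
        simple_smem_indexer tokens{smem + 2 * smem_cols, smem_cols}; // 2D view: (smem_rows - 2, smem_cols)
        cumsum(0)    = 0; // first cumulative offset
        tokens(3, 2) = 1; // row 3 (sub-token), column 2 (expert id)
        (void)cumdup;
    }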
    CK_TILE_DEVICE void
    moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
                                   const WeightType* __restrict__ weights,
                                   const IndexType* __restrict__ local_expert_mask,
                                   index_t* p_sorted_token_ids,
                                   WeightType* p_sorted_weights,
                                   index_t* p_sorted_expert_ids,
                                   index_t* p_total_tokens_post_pad,
                                   const index_t num_experts,
                                   const index_t tokens,
                                   const mdiv unit_size_mdiv,
                                   const mdiv topk_mdiv,
                                   const mdiv expert_mdiv,
                                   const index_t smem_rows,
                                   void* smem) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
        const index_t lid = __lane_id();

        constexpr index_t block_size = 256;           // blockDim.x;
        const index_t sub_tokens     = smem_rows - 2; // sub_tokens_mdiv.divisor;
        const index_t topk           = topk_mdiv.divisor;

        auto f_sum = [](auto x_, auto y_) { return x_ + y_; };

        const index_t smem_cols = num_experts + 1;
        simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
        simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
        simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols, smem_cols};

        // #pragma unroll 8
        for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
        {
            uint32_t curr_token_id, curr_expert_id;
            expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
            smem_tokens(curr_token_id, curr_expert_id) = 0;
        }
        __syncthreads();

        for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
        {
            // NOTE: below for loop can't have a barrier inside!!
            for(int i = tid; i < (sub_tokens * topk); i += block_size)
            {
                uint32_t curr_token_id, curr_topk_id;
                topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
                int i_t = i_token + curr_token_id;
                if(i_t < tokens)
                {
                    int eid = topk_id[i_t * topk + curr_topk_id];
                    if constexpr(Problem::SubTokenOneShot)
                        smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
                    else
                        smem_tokens(curr_token_id, eid)++;
                }
                __builtin_amdgcn_s_waitcnt(0xc07f);
            }
            __syncthreads(); // make sure different i_token iterations are not overlapped by different waves
        }

        // counting
        if(tid == 0)
        {
            smem_cumsum(0) = 0;
            // smem_cumdup(0) = 0;
        }
        {
            constexpr int lane_group_sz = 8;
            int lane_group_id           = tid / lane_group_sz;
            int lane_group_os           = tid % lane_group_sz;
            constexpr int lane_group_nm = block_size / lane_group_sz;
            for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
            {
                index_t local_c[Problem::SubTokenTile];
                index_t cnt = 0;
                for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
                {
#pragma unroll Problem::SubTokenTile
                    for(int j = 0; j < Problem::SubTokenTile; j++)
                    {
                        local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
                        if constexpr(Problem::SubTokenOneShot)
                        {
                            local_c[j] = local_c[j] != 0 ? 1 : 0;
                        }
                    }
#pragma unroll Problem::SubTokenTile
                    for(int j = 0; j < Problem::SubTokenTile; j++)
                    {
                        cnt += wave_reduce(local_c[j], f_sum, number<8>{});
                    }
                }
                if(lane_group_os == 0)
                    smem_cumsum(i_e + 1) = cnt;
            }
        }
        if constexpr(Problem::LocalExpertMasking)
        {
            smem_cumdup(0) = 0;
            for(int i_e = tid; i_e < num_experts; i_e += block_size)
            {
                // reuse this buffer
                smem_cumdup(i_e + 1) = local_expert_mask[i_e];
            }
        }
        __syncthreads();
        {
            if(wid == 0)
            {
                // NOTE: inside this block we can never use __syncthreads!
                int i_e_          = 0;
                int local_cumsum_ = 0;
                for(; i_e_ < num_experts; i_e_ += warpSize)
                {
                    int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
                    int local_cnt   = smem_cumsum(i_e_ + lid + 1);
                    int blocks_pers_expert =
                        unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
                    int pre_cumsum_masking = [&]() {
                        if constexpr(Problem::LocalExpertMasking)
                            return smem_cumdup(lid == 0 ? i_e_ : 0);
                        else
                            return 0; // not used
                    }();
                    int local_masking = [&]() {
                        if constexpr(Problem::LocalExpertMasking)
                            return smem_cumdup(i_e_ + lid + 1);
                        else
                            return 0; // not used
                    }();
                    int padded_tokens_per_expert = [&]() {
                        int x_ = [&]() {
                            if constexpr(Problem::SkipExpertsWithZeroTokens)
                            {
                                // if local_cnt is zero, blocks_pers_expert will be zero
                                // this is what we want to achieve
                                return blocks_pers_expert * unit_size_mdiv.divisor;
                            }
                            else
                            {
                                return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
                            }
                        }();
                        if constexpr(Problem::LocalExpertMasking)
                        {
                            return local_masking ? x_ : 0;
                        }
                        else
                            return x_;
                    }();
                    local_cumsum_ = padded_tokens_per_expert;
                    local_cumsum_ += pre_cumsum_; // note: pre_cumsum must be added after the local
                                                  // cumsum is padded, in case the local cumsum is zero
                                                  // while pre_cumsum has a value, which would otherwise
                                                  // give a zero local cumsum (but we want at least one padded unit)
                    wave_cumsum<int, warpSize>(local_cumsum_);
                    if((i_e_ + lid) < num_experts)
                        smem_cumsum(i_e_ + lid + 1) = local_cumsum_;

                    if constexpr(Problem::LocalExpertMasking)
                    {
                        local_masking += pre_cumsum_masking;
                        wave_cumsum<int, warpSize>(local_masking);
                        if((i_e_ + lid) < num_experts)
                            smem_cumdup(i_e_ + lid + 1) = local_masking;
                    }
                    // NOTE: this waitcnt is a must; the compiler will not generate waitcnt lgkmcnt()
                    // for the above write, while __syncthreads would create a barrier with waves
                    // other than 0 (which is not what we want)
                    __builtin_amdgcn_s_waitcnt(0xc07f);
                }
                if((lid + i_e_ - warpSize) == (num_experts - 1))
                {
                    *p_total_tokens_post_pad = local_cumsum_;
                }
            }
            __syncthreads();
        }

        for(int i_e = tid; i_e < num_experts; i_e += block_size)
        {
            int e_start   = smem_cumsum(i_e);
            int e_end     = smem_cumsum(i_e + 1);
            int expert_id = [&]() {
                if constexpr(Problem::LocalExpertMasking)
                {
                    // local expert id from cumsum
                    return smem_cumdup(i_e);
                }
                else
                    return i_e;
            }();
            smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
            if constexpr(Problem::SkipExpertsWithZeroTokens)
            {
                if(e_start == e_end) // skip zero token expert
                    continue;
            }
            if constexpr(Problem::LocalExpertMasking)
            {
                if(local_expert_mask[i_e] == 0)
                    continue;
            }
            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
            {
                p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
            }
        }
        smem_cumdup(num_experts) = smem_cumsum(num_experts);

        // fill the p_sorted_token_ids/p_sorted_weights
        for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
        {
            if constexpr(!Problem::SubTokenOneShot)
            {
                // clear every time
                for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
                {
                    uint32_t curr_token_id, curr_expert_id;
                    expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
                    smem_tokens(curr_token_id, curr_expert_id) = 0;
                }
                __syncthreads();
                // load again
                for(int i = tid; i < (sub_tokens * topk); i += block_size)
                {
                    uint32_t curr_token_id_, curr_topk_id_;
                    topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
                    int curr_token_id = static_cast<int>(curr_token_id_);
                    int curr_topk_id  = static_cast<int>(curr_topk_id_);
                    int i_t           = i_token + curr_token_id;
                    if(i_t < tokens)
                    {
                        int eid = topk_id[i_t * topk + curr_topk_id];
                        smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
                    }
                }
                __syncthreads();
            }
            {
                constexpr int lane_group_sz = 8;
                int lane_group_id           = tid / lane_group_sz;
                int lane_group_os           = tid % lane_group_sz;
                constexpr int lane_group_nm = block_size / lane_group_sz;
                for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
                {
                    if constexpr(Problem::LocalExpertMasking)
                    {
                        if(local_expert_mask[eid] == 0)
                            continue;
                    }
                    int position = smem_cumsum(eid);
                    for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
                        i_sub_token += lane_group_sz)
                    {
                        auto x              = smem_tokens(i_sub_token, eid);
                        int local_cnt_cache = x != 0 ? 1 : 0;
                        int local_cnt       = local_cnt_cache;
                        wave_cumsum<int, lane_group_sz>(local_cnt);
                        if(x != 0)
                        {
                            // now x is the topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                            p_sorted_token_ids[position + local_cnt - 1] =
                                MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
                            p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
                            p_sorted_weights[position + local_cnt - 1] =
                                weights[(i_token + i_sub_token) * topk + x - 1];
                        }
                        int remote_cnt = __builtin_amdgcn_ds_bpermute(
                            (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
                        position += remote_cnt;
                    }
                    smem_cumsum(eid) = position;
                }
            }
            __syncthreads();
        }

        // add the skip number
        for(int eid = tid; eid < num_experts; eid += block_size)
        {
            int e_start = smem_cumsum(eid);
            int e_end   = smem_cumdup(eid + 1);
            if constexpr(Problem::SkipExpertsWithZeroTokens)
            {
                if(e_start == e_end) // skip zero token expert
                    continue;
            }
            while(e_start < e_end)
            {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
                p_sorted_token_ids[e_start] = tokens;
#endif
                p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
                e_start++;
            }
        }
    }
...
...
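Taken together, the kernel emits one p_sorted_expert_ids entry per unit of unit_size rows, and groups p_sorted_token_ids/p_sorted_weights by expert with padding slots that carry the out-of-range token id `tokens` and weight 0. A hedged host-side sketch of that layout for a single expert (all names illustrative):

    // Hedged sketch (host-side, illustrative): output layout for one expert.
    inline void reference_fill_one_expert(int* sorted_token_ids,
                                          float* sorted_weights,
                                          const int* expert_token_ids,
                                          const float* expert_token_weights,
                                          int cnt, int offset, int unit_size, int tokens)
    {
        int rows = ((cnt + unit_size - 1) / unit_size) * unit_size;
        if(rows == 0)
            rows = unit_size; // mirrors the "at least one unit" behaviour of the non-skipping paths
        for(int r = 0; r < rows; ++r)
        {
            sorted_token_ids[offset + r] = r < cnt ? expert_token_ids[r] : tokens; // padding id
            sorted_weights[offset + r]   = r < cnt ? expert_token_weights[r] : 0.0f;
        }
    }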
@@ -456,6 +1012,24 @@ struct MoeSortingKernel
        }
        const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
        extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
        (void)numel;
        return moe_align_block_size_kernel_ex(
            static_cast<const IndexType*>(kargs.p_topk_ids),
            static_cast<const WeightType*>(kargs.p_weights),
            static_cast<const IndexType*>(kargs.p_local_expert_mask),
            static_cast<IndexType*>(kargs.p_sorted_token_ids),
            static_cast<WeightType*>(kargs.p_sorted_weights),
            static_cast<IndexType*>(kargs.p_sorted_expert_ids),
            static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
            kargs.num_experts,
            kargs.tokens,
            kargs.unit_size_mdiv,
            kargs.topk_mdiv,
            kargs.expert_mdiv,
            kargs.smem_rows,
            smem);
#else
        return moe_align_block_size_kernel(
            static_cast<const IndexType*>(kargs.p_topk_ids),
            static_cast<const WeightType*>(kargs.p_weights),
            static_cast<IndexType*>(kargs.p_sorted_token_ids),
...
...
@@ -468,6 +1042,7 @@ struct MoeSortingKernel
            kargs.unit_size_mdiv,
            kargs.topk_mdiv,
            smem);
#endif
    }
};
...
...
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp → include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
View file @ 7572a691
...
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
        InternalLoadUnroll_; // TODO: need better design(like tile size)
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};

template <typename IndexType_,
          typename WeightType_,
          index_t SubTokenTile_,    // 1,2,4,8, or 0 in the future
          bool SubTokenOneShot_,    // if we only loop over once or not
          bool LocalExpertMasking_, // used in EP case
          bool SkipExpertsWithZeroTokens_ = true,
          index_t ExpertTile_             = 0>
struct MoeSortingProblemEx
{
    // TODO: this kernel only support warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t WarpSize      = get_warp_size();
    static constexpr index_t WarpsPerBlock = 1;

    static constexpr index_t SubTokenTile           = SubTokenTile_;
    static constexpr bool SubTokenOneShot           = SubTokenOneShot_;
    static constexpr bool LocalExpertMasking        = LocalExpertMasking_;
    static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
    static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);

    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
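A hedged example of how the new problem type might be instantiated; the data types and flags below are illustrative choices, not a configuration taken from the library:

// Hedged sketch: one plausible instantiation of the extended problem descriptor.
using ExampleMoeSortingProblem = ck_tile::MoeSortingProblemEx<ck_tile::index_t, // IndexType_
                                                              float,            // WeightType_
                                                              4,                // SubTokenTile_ (4 sub-token rows per lane)
                                                              true,             // SubTokenOneShot_ (single smem pass)
                                                              false,            // LocalExpertMasking_ (no EP masking)
                                                              true,             // SkipExpertsWithZeroTokens_
                                                              0>;               // ExpertTile_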
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @ 7572a691
...
...
@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
#if 1
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0, max(smem_1, smem_bridge));
        return max(smem_0 + smem_1, smem_bridge);
#else
        // keep it here purposely in case we have regression
        return 65536;
#endif
    }

    // this is the thread-offset along row/col
...
...
@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}([&](auto i) {
            row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            row_ids.at(i) &= 0xffffff;
#endif
        });
        return row_ids;
...
...
@@ -165,8 +173,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
                                   index_t intermediate_tile_id)
    {
        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size * hidden_radio_0; // total gate+up
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
        // after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
...
...
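With this convention, intermediate_size is the per-projection width and the "0"-side size covers gate and up together. A hedged numeric sketch, using made-up sizes that are not taken from any real configuration:

// Hedged numeric sketch (made-up sizes): gate+up doubles the '0'-side width.
constexpr int intermediate_size = 4096, hidden_size = 2048;
constexpr int Warp_N0 = 64, Warp_K0 = 64;
constexpr int hidden_radio_0 = 2;                                               // gate + up
constexpr int shared_intermediate_size_0 = intermediate_size * hidden_radio_0;  // 8192
constexpr int nr_0 = shared_intermediate_size_0 / Warp_N0;                      // 128
constexpr int kr_0 = hidden_size / Warp_K0;                                     // 32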
@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk
            make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
        auto g_win = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
        auto make_gu_win = [&](const auto* ptr_) {
            auto view_ = make_naive_tensor_view<address_space_enum::global>(
                ptr_,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});
            auto g_window_ = make_tile_window_linear_raw(
                g_view_,
            auto win_ = make_tile_window_linear_raw(
                view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>(),
                sequence<0, 1, 1>{});
            return g_window_;
        }();
            return win_;
        };

        const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                  static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                  interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
        auto g_win = make_gu_win(gu_ptr);
        // Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
        auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);

        auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
                                       number<decltype(g_win)::NumAccess_NonLinear>{});
...
...
@@ -310,6 +327,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

        auto uk_0  = Policy::template GetUK_0<Problem>();
        auto y_pre = [&]() {
            if constexpr(IsGateOnly)
            {
                auto acc_0 = uk_0(a_res,
                                  a_coords,
                                  g_res,
...
...
@@ -330,7 +351,48 @@ struct FusedMoeGemmPipeline_FlatmmUk
                    },
                    sequence<1, 2>{});
                auto y_pre = cast_tile<YDataType>(acc_0);
                return cast_tile<YDataType>(acc_0);
            }
            else
            {
                uint32x8_t gu_res;
                gu_res[0] = g_res[0];
                gu_res[1] = g_res[1];
                gu_res[2] = g_res[2];
                gu_res[3] = g_res[3];
                gu_res[4] = u_res[0];
                gu_res[5] = u_res[1];
                gu_res[6] = u_res[2];
                gu_res[7] = u_res[3];

                auto acc_0 = uk_0(a_res,
                                  a_coords,
                                  gu_res,
                                  g_coords,
                                  smem,
                                  kargs.hidden_size,
                                  BlockShape::Block_K0,
                                  BlockShape::Block_Kr0 * BlockShape::Block_W0, // tile offset for B matrix each unroll
                                  bool_constant<true>{});

                sweep_tile(
                    acc_0.at(number<0>{}),
                    [&](auto idx0, auto idx1) {
                        fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
                        typename Problem::GateActivation{}(v_, v_);
                        acc_0.at(number<0>{})(idx0) = v_.x;
                        acc_0.at(number<0>{})(idx1) = v_.y;
                    },
                    sequence<1, 2>{});

                auto reduced_acc_0 = tile_elementwise_in(
                    [&](const auto& a_, const auto& b_) { return a_ * b_; },
                    acc_0.at(number<0>{}),
                    acc_0.at(number<1>{}));
                return cast_tile<YDataType>(reduced_acc_0);
            }
        }();
        block_sync_lds();
...
...
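In the gate+up path the accumulator carries two planes: plane 0 holds the gate projection and plane 1 the up projection; the activation is applied to the gate and multiplied element-wise with up. A hedged scalar sketch of that per-element combination; silu (and the use of <cmath>) is only an illustrative stand-in for Problem::GateActivation:

#include <cmath>
// Hedged sketch: 'gate' is one element of acc_0 plane 0, 'up' the matching element of plane 1.
inline float reference_gate_up(float gate, float up)
{
    const float activated = gate / (1.0f + std::exp(-gate)); // silu(gate), as an example
    return activated * up;                                   // this product is what gets cast to YDataType
}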
include/ck_tile/ops/gemm.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
@@ -46,3 +48,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @ 7572a691
...
...
@@ -14,8 +14,12 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;
    private:
    template <typename PipelineProblem_, typename GemmPolicy_>
    struct GemmTraits_
    {
        using Problem   = remove_cvref_t<PipelineProblem_>;
        using Policy    = remove_cvref_t<GemmPolicy_>;
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using CDataType = remove_cvref_t<typename Problem::CDataType>;
...
...
@@ -23,33 +27,44 @@ struct BlockGemmARegBRegCRegV1
        static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                          std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
                          std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
                      "wrong!");
        static constexpr index_t MPerBlock = BlockGemmShape::kM;
        static constexpr index_t NPerBlock = BlockGemmShape::kN;
        static constexpr index_t KPerBlock = BlockGemmShape::kK;
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
        static constexpr index_t MWarp = config.template at<1>();
        static constexpr index_t NWarp = config.template at<2>();

        static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
        static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        static constexpr index_t KPack = WarpGemm::kKPerThread;
    };
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

    public:
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using Traits         = GemmTraits_<Problem, Policy>;
    using WarpGemm       = typename Traits::WarpGemm;
    using BlockGemmShape = typename Traits::BlockGemmShape;
    using ADataType      = remove_cvref_t<typename Traits::ADataType>;
    using BDataType      = remove_cvref_t<typename Traits::BDataType>;
    using CDataType      = remove_cvref_t<typename Traits::CDataType>;
    static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
    static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
    static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
    static constexpr index_t MWarp = Traits::MWarp;
    static constexpr index_t NWarp = Traits::NWarp;

    // M->N Warp
    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
...
...
@@ -57,7 +72,14 @@ struct BlockGemmARegBRegCRegV1
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
        return a_block_dstr_encode;
    }

    CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
    {
        constexpr auto b_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
...
...
@@ -65,7 +87,14 @@ struct BlockGemmARegBRegCRegV1
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
        return b_block_dstr_encode;
    }

    CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
    {
        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
...
...
@@ -73,15 +102,28 @@ struct BlockGemmARegBRegCRegV1
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
        return c_block_dstr_encode;
    }
        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});

    // C += A * B
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                          std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
                          std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
                      "wrong!");
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
        constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
        constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();

        // check ABC-block-distribution
        static_assert(
...
...
@@ -100,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
                .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using BWarpDstr = typename WG::BWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;
        using AWarpDstr = typename WarpGemm::AWarpDstr;
        using BWarpDstr = typename WarpGemm::BWarpDstr;
        using CWarpDstr = typename WarpGemm::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using BWarpTensor = typename WG::BWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;
        using AWarpTensor = typename WarpGemm::AWarpTensor;
        using BWarpTensor = typename WarpGemm::BWarpTensor;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...
...
@@ -145,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

            // warp GEMM
            WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
            WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

            // write C warp tensor into C block tensor
            c_block_tensor.set_y_sliced_thread_data(
...
...
@@ -159,20 +201,6 @@ struct BlockGemmARegBRegCRegV1
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MIterPerWarp, MWarp>,
                                             sequence<NIterPerWarp, NWarp>>,
...
...
@@ -182,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
                                       sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
...
...
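A hedged usage sketch of this block GEMM building block; the tile types are whatever the surrounding pipeline produces, and clear_tile is assumed to be the usual ck_tile helper — only the MakeCBlockTile() accumulator creation and the C += A * B call are the point here:

// Hedged sketch: the call pattern a pipeline typically uses with BlockGemmARegBRegCRegV1.
template <typename BlockGemm, typename ATile, typename BTile>
CK_TILE_DEVICE auto example_block_gemm_step(const ATile& a_block_tensor, const BTile& b_block_tensor)
{
    BlockGemm block_gemm{};
    auto c_block_tensor = BlockGemm::MakeCBlockTile();           // distributed C accumulator
    clear_tile(c_block_tensor);                                  // assumed ck_tile helper
    block_gemm(c_block_tensor, a_block_tensor, b_block_tensor);  // C += A * B
    return c_block_tensor;
}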
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
View file @ 7572a691
...
...
@@ -79,8 +79,11 @@ struct BlockUniversalGemmAsBsCr
        // TODO: Should we have two policies? Interwave & Intrawave ??
        static constexpr index_t InterWaveSchedulingMacClusters = 1;

        static constexpr index_t KPack      = WarpGemm::kKPerThread;
        static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
        // should be at least equal to: WarpGemm::Impl::kABKPerLane
        // and the question is how to assess upper limit or exact value?
        // TODO: Should we introduce AK1/BK1 parameters ?
        static constexpr index_t KPack      = 8;
        static constexpr index_t KPerThread = KIterPerWarp * KPack;
        static constexpr index_t KRepeat    = KPerThread / KPack;
    };
...
...
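A hedged numeric sketch of the changed constants; the tile sizes below are made up, not a shipped configuration:

// Hedged numeric sketch (made-up tile sizes) for the constants above.
constexpr int KPerBlock    = 32;
constexpr int WarpGemm_kK  = 16;                      // stands in for WarpGemm::kK
constexpr int KIterPerWarp = KPerBlock / WarpGemm_kK; // 2
constexpr int KPack        = 8;
constexpr int KPerThread   = KIterPerWarp * KPack;    // 16
constexpr int KRepeat      = KPerThread / KPack;      // 2
static_assert(KRepeat == KIterPerWarp, "with this definition KRepeat mirrors KIterPerWarp");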