Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
199f7f71
Commit
199f7f71
authored
Sep 01, 2024
by
carlushuang
Browse files
modify moe
parent
33ceea62
Changes
22
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2452 additions
and
362 deletions
+2452
-362
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
...tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
+15
-0
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
.../ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
+292
-78
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
...e/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
+130
-250
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp
...include/ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp
+410
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp
..._tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp
+27
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
...ile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
+464
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
.../fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
+668
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
...ile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
+1
-11
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp
+116
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
...clude/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
+23
-0
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
.../ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
+15
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+115
-12
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+14
-0
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+25
-0
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+20
-5
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+26
-0
include/ck_tile/core/tensor/tensor_view.hpp
include/ck_tile/core/tensor/tensor_view.hpp
+17
-0
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+68
-0
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+5
-5
No files found.
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_
traits
.hpp
→
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_
permute_enum
.hpp
View file @
199f7f71
...
@@ -3,16 +3,8 @@
...
@@ -3,16 +3,8 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
enum
class
FusedMoeWeightPermuteEnum
enum
class
FusedMoePermuteStyle
{
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
...
@@ -20,14 +12,4 @@ enum class FusedMoePermuteStyle
...
@@ -20,14 +12,4 @@ enum class FusedMoePermuteStyle
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
no_permute
=
999
,
no_permute
=
999
,
};
};
}
template
<
bool
DownPreShuffled_
=
false
,
FusedMoePermuteStyle
PermuteStyle_
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
FusedMoeTraits
{
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
FusedMoePermuteStyle
PermuteStyle
=
PermuteStyle_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
View file @
199f7f71
...
@@ -14,7 +14,7 @@ namespace ck_tile {
...
@@ -14,7 +14,7 @@ namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
>
struct
BlockFmhaPipelineQRKSVSAsync
struct
FusedMoePipeline
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
@@ -27,43 +27,49 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -27,43 +27,49 @@ struct BlockFmhaPipelineQRKSVSAsync
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
ScaleDataType
>
;
using
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
ScaleDataType
>
;
using
FusedMoeTileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
using
FusedMoeTileShape
=
remove_cvref_t
<
typename
Problem
::
FusedMoeTileShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
FusedMoeTileShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
index_t
kBlockM_0
=
FusedMoeTileShape
::
kBlockM_0
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
static
constexpr
index_t
kBlockN_0
=
FusedMoeTileShape
::
kBlockN_0
;
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static
constexpr
index_t
kBlockK_0
=
FusedMoeTileShape
::
kBlockK_0
;
static_assert
(
Problem
::
kPadSeqLenQ
==
true
&&
Problem
::
kPadHeadDimQ
==
true
&&
static
constexpr
index_t
kWarpM_0
=
FusedMoeTileShape
::
kWarpM_0
;
Problem
::
kPadHeadDimV
==
true
);
static
constexpr
index_t
kWarpN_0
=
FusedMoeTileShape
::
kWarpN_0
;
static
constexpr
bool
kPadSeqLenQ
=
true
;
static
constexpr
index_t
kWarpK_0
=
FusedMoeTileShape
::
kWarpK_0
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
index_t
kBlockWarpsM_0
=
FusedMoeTileShape
::
kBlockWarpsM_0
;
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
index_t
kBlockWarpsN_0
=
FusedMoeTileShape
::
kBlockWarpsN_0
;
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
index_t
kBlockWarpsK_0
=
FusedMoeTileShape
::
kBlockWarpsK_0
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
index_t
kSubBlockM_0
=
FusedMoeTileShape
::
kSubBlockM_0
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
index_t
kSubBlockN_0
=
FusedMoeTileShape
::
kSubBlockN_0
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
static
constexpr
index_t
kSubBlockK_0
=
FusedMoeTileShape
::
kSubBlockK_0
;
static
constexpr
index_t
kWarpRepeatM_0
=
FusedMoeTileShape
::
kWarpRepeatM_0
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
static
constexpr
index_t
kWarpRepeatN_0
=
FusedMoeTileShape
::
kWarpRepeatN_0
;
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kWarpRepeatK_0
=
FusedMoeTileShape
::
kWarpRepeatK_0
;
static
constexpr
index_t
kAlignmentQ
=
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kBlockM_1
=
FusedMoeTileShape
::
kBlockM_1
;
static
constexpr
index_t
kAlignmentV
=
[]()
{
static
constexpr
index_t
kBlockN_1
=
FusedMoeTileShape
::
kBlockN_1
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
static
constexpr
index_t
kBlockK_1
=
FusedMoeTileShape
::
kBlockK_1
;
return
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kWarpM_1
=
FusedMoeTileShape
::
kWarpM_1
;
else
static
constexpr
index_t
kWarpN_1
=
FusedMoeTileShape
::
kWarpN_1
;
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
static
constexpr
index_t
kWarpK_1
=
FusedMoeTileShape
::
kWarpK_1
;
}();
static
constexpr
index_t
kBlockWarpsM_1
=
FusedMoeTileShape
::
kBlockWarpsM_1
;
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kBlockWarpsN_1
=
FusedMoeTileShape
::
kBlockWarpsN_1
;
static
constexpr
index_t
kAlignmentBias
=
static
constexpr
index_t
kBlockWarpsK_1
=
FusedMoeTileShape
::
kBlockWarpsK_1
;
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kSubBlockM_1
=
FusedMoeTileShape
::
kSubBlockM_1
;
static
constexpr
index_t
kSubBlockN_1
=
FusedMoeTileShape
::
kSubBlockN_1
;
#if CK_TILE_FMHA_FWD_FAST_EXP2
static
constexpr
index_t
kSubBlockK_1
=
FusedMoeTileShape
::
kSubBlockK_1
;
static
constexpr
auto
R_LOG2E
=
1.0
/
log2e_v
<
SaccDataType
>
;
static
constexpr
index_t
kWarpRepeatM_1
=
FusedMoeTileShape
::
kWarpRepeatM_1
;
#endif
static
constexpr
index_t
kWarpRepeatN_1
=
FusedMoeTileShape
::
kWarpRepeatN_1
;
static
constexpr
index_t
kWarpRepeatK_1
=
FusedMoeTileShape
::
kWarpRepeatK_1
;
using
MBlockType
=
decltype
(
GetMatrixCoreSwizzledBlockTIle_0
<
Problem
>
());
static
constexpr
index_t
kBlockNr_0
=
MBlockType
{}
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kBlockKr_0
=
MBlockType
{}
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kBlockWaveFlatten
=
MBlockType
{}
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockPerCu
=
[]()
{
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
...
@@ -71,37 +77,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -71,37 +77,7 @@ struct BlockFmhaPipelineQRKSVSAsync
else
else
{
{
// minimize occupancy
// minimize occupancy
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)
return
2
;
{
return
1
;
}
if
constexpr
(
kK0BlockLength
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
FmhaMask
::
IsMasking
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kK0BlockLength
<=
256
)
{
return
1
;
}
}
}
}();
}();
...
@@ -179,23 +155,261 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -179,23 +155,261 @@ struct BlockFmhaPipelineQRKSVSAsync
o_gtile_window_tmp
.
get_window_lengths
(),
o_gtile_window_tmp
.
get_window_lengths
(),
o_gtile_window_tmp
.
get_window_origin
(),
o_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>());
using
g_thread_type
=
decltype
(
load_tile
(
g_gtile_window
));
using
u_thread_type
=
decltype
(
load_tile
(
u_gtile_window
));
using
d_thread_type
=
decltype
(
load_tile
(
d_gtile_window
));
const
index_t
loops_0
=
(
dim_size
+
kBlockK_0
-
1
)
/
kBlockK_0
;
const
index_t
loops_1
=
(
dim_size
+
kBlockN_1
-
1
)
/
kBlockN_1
;
// auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
// issues_warps_lanes
auto
a_sst_0
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
// issues_warps_lanes
auto
a_sst_1
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
// m*k
auto
a_sld_0
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// m*k
auto
a_sld_1
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
g_thread_type
g_tile
[
2
];
using
WarpGemm0
=
Policy
::
GetWarpGemm0
<
Problem
>
();
using
WarpGemm1
=
Policy
::
GetWarpGemm1
<
Problem
>
();
auto
warp_gemm_0
=
WarpGemm0
{};
auto
warp_gemm_1
=
WarpGemm1
{};
// TODO: N fist, M next
const
index_t
i_mwarp_0
=
get_warp_id
()
/
kBlockWarpsN_0
;
// create and pre-cache a warp-window
auto
make_a_warp_windows
=
[
&
](
auto
a_sld_
)
{
// construct A-warp-window
auto
warp_window
=
make_tile_window
(
a_sld_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm0
::
kM
>
{},
number
<
WarpGemm0
::
kK
>
{}),
a_sld_
.
get_window_origin
()
+
multi_index
<
2
>
{
i_mwarp_0
*
WarpGemm0
::
kM
,
0
},
make_static_tile_distribution
(
typename
WarpGemm0
::
AWarpDstrEncoding
{}));
statically_indexed_array
<
statically_indexed_array
<
decltype
(
warp_window
),
kWarpRepeatK_0
>
,
kWarpRepeatM_0
>
ws
;
// pre-cache the warp windows
static_for
<
0
,
kWarpRepeatM_0
,
1
>
{}([
&
](
auto
i_m_iter
)
{
static_for
<
0
,
kWarpRepeatK_0
,
1
>
{}([
&
](
auto
i_k_iter
)
{
ws
(
i_m_iter
)(
i_k_iter
)
=
warp_window
;
move_tile_window
(
ws
(
i_m_iter
)(
i_k_iter
),
{
i_m_iter
*
NPerBlockPerIter
,
i_k_iter
*
KPerBlockPerIter
});
});
});
return
ws
;
};
auto
a_warp_windows_0
=
make_a_warp_windows
(
a_sld_0
);
auto
a_warp_windows_1
=
make_a_warp_windows
(
a_sld_1
);
constexpr
auto
true_v
=
bool_constant
<
true
>
{};
constexpr
auto
false_v
=
bool_constant
<
false
>
{};
auto
do_load_a0
=
[
&
](
auto
&
a_store_
,
auto
move_
)
{
async_load_tile
(
a_store_
,
a_gtile_window
);
if
constexpr
(
move_
)
move_tile_window
(
a_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockK_0
>
{}});
};
auto
do_load_b0
=
[
&
](
auto
&
g_tile_
,
auto
&
u_tile_
,
auto
move_
)
{
g_tile_
=
load_tile
(
g_gtile_window
);
u_tile_
=
load_tile
(
u_gtile_window
);
if
constexpr
(
move_
)
{
move_tile_window
(
g_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
move_tile_window
(
u_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
}
};
auto
do_load_b1
=
[
&
](
auto
&
d_tile_
,
auto
move_
)
{
d_tile_
=
load_tile
(
d_gtile_window
);
if
constexpr
(
move_
)
{
move_tile_window
(
d_gtile_window
,
{
number
<
0
>
{},
number
<
kBlockKr_0
>
{},
number
<
0
>
{}});
}
};
// using AWarpTensor = typename decltype(warp_gemm_0)::AWarpTensor{};
// using CWarpTensor =
auto
acc_g
=
MakeCBlockTile_Gemm0
<
Problem
>
();
auto
acc_u
=
MakeCBlockTile_Gemm0
<
Problem
>
();
// async_load_tile(a_sst_0, a_gtile_window); move_tile_window(a_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); g_tile[0] = load_tile(g_gtile_window);
// move_tile_window(g_gtile_window, {number<0>{}, number<kBlockK_0>{}}); u_tile[0] =
// load_tile(u_gtile_window); move_tile_window(u_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); async_load_tile(a_sst_1, a_gtile_window);
// move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}}); g_tile[1] =
// load_tile(g_gtile_window); move_tile_window(g_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); u_tile[1] = load_tile(u_gtile_window);
// move_tile_window(u_gtile_window, {number<0>{}, number<kBlockK_0>{}});
auto
do_gemm_0
=
[
&
](
auto
&
acc_g_
,
auto
&
acc_u_
,
auto
&
a_windows_
,
auto
&
g_tile_
,
auto
&
u_tile_
)
{
// as_br (asmem, breg)
static_for
<
0
,
kWarpRepeatK_0
,
1
>
{}([
&
](
auto
i_k
)
{
static_for
<
0
,
kWarpRepeatM_0
,
1
>
{}([
&
](
auto
i_m
)
{
const
auto
w_a
=
load_tile
(
a_windows_
(
i_m
)(
i_k
));
static_for
<
0
,
kWarpRepeatN_0
,
1
>
{}([
&
](
auto
i_n
)
{
constexpr
auto
beg_acc
=
sequence
<
i_m
*
kSubBlockM_0
,
i_n
*
kSubBlockN_0
>
{};
constexpr
auto
end_acc
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
>
{};
// 3d indexing for permuted g/u/d
constexpr
auto
beg_b
=
sequence
<
i_m
*
kBlockWarpsM_0
,
i_n
*
kSubBlockN_0
,
0
>
{};
constexpr
auto
end_b
=
sequence
<
(
i_m
+
1
)
*
kBlockWarpsM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
,
0
>
{};
auto
w_acc_g
=
get_slice_tile
(
acc_g_
,
beg_acc
,
end_acc
);
auto
w_acc_u
=
get_slice_tile
(
acc_u_
,
beg_acc
,
end_acc
);
auto
w_g
=
get_slice_tile
(
g_tile_
,
beg_b
,
end_b
);
auto
w_u
=
get_slice_tile
(
u_tile_
,
beg_b
,
end_b
);
warp_gemm_0
(
w_acc_g
,
w_a
,
w_g
);
warp_gemm_0
(
w_acc_u
,
w_a
,
w_u
);
set_slice_tile
(
acc_g_
,
w_acc_g
,
beg_acc
,
end_acc
);
set_slice_tile
(
acc_u_
,
w_acc_u
,
beg_acc
,
end_acc
);
});
});
});
};
auto
do_gemm_1
=
[
&
](
auto
&
acc_d_
,
auto
&
a_tile_
,
auto
&
d_tile_
)
{
// ar_br (areg, breg)
static_for
<
0
,
kWarpRepeatK_1
,
1
>
{}([
&
](
auto
i_k
)
{
static_for
<
0
,
kWarpRepeatM_1
,
1
>
{}([
&
](
auto
i_m
)
{
constexpr
auto
beg_a
=
sequence
<
i_m
*
kSubBlockM_1
,
i_k
*
kSubBlockK_1
>
{};
constexpr
auto
end_a
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_1
,
(
i_k
+
1
)
*
kSubBlockK_1
>
{};
const
auto
w_a
=
get_slice_tile
(
a_tile_
,
beg_a
,
end_a
);
static_for
<
0
,
kWarpRepeatN_1
,
1
>
{}([
&
](
auto
i_n
)
{
constexpr
auto
beg_acc
=
sequence
<
i_m
*
kSubBlockM_0
,
i_n
*
kSubBlockN_0
>
{};
constexpr
auto
end_acc
=
sequence
<
(
i_m
+
1
)
*
kSubBlockM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
>
{};
// 3d indexing for permuted g/u/d
constexpr
auto
beg_b
=
sequence
<
i_m
*
kBlockWarpsM_0
,
i_n
*
kSubBlockN_0
,
0
>
{};
constexpr
auto
end_b
=
sequence
<
(
i_m
+
1
)
*
kBlockWarpsM_0
,
(
i_n
+
1
)
*
kSubBlockN_0
,
0
>
{};
auto
w_acc_d
=
get_slice_tile
(
acc_d_
,
beg_acc
,
end_acc
);
auto
w_d
=
get_slice_tile
(
d_tile_
,
beg_b
,
end_b
);
warp_gemm_1
(
w_acc_d
,
w_a
,
w_d
);
set_slice_tile
(
acc_d_
,
w_acc_d
,
beg_acc
,
end_acc
);
});
});
});
};
// start of pipeline
do_load_a0
(
a_sst_0
,
true_v
);
do_load_b0
(
g_tile
[
0
],
u_tile
[
0
],
true_v
);
do_load_a0
(
a_sst_1
,
true_v
);
do_load_b0
(
g_tile
[
1
],
u_tile
[
1
],
true_v
);
clear_tile
(
acc_g
);
clear_tile
(
acc_u
);
constexpr
auto
k_per_block_0
=
Problem
::
FusedMoeTileShape
::
kK_a
;
index_t
i_0
=
0
;
const
index_t
loops_0
=
(
dim_size
+
k_per_block_0
-
1
)
/
k_per_block_0
;
while
(
i_0
<
(
loops_0
-
2
))
{
// first buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_0
,
g_tile
[
0
],
u_tile
[
0
]);
do_load_a0
(
a_sst_0
,
true_v
);
do_load_b0
(
g_tile
[
0
],
u_tile
[
0
],
true_v
);
i_0
++
;
// second buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_1
,
g_tile
[
1
],
u_tile
[
1
]);
do_load_a0
(
a_sst_1
,
true_v
);
do_load_b0
(
g_tile
[
1
],
u_tile
[
1
],
true_v
);
i_0
++
;
}
// first buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_0
,
g_tile
[
0
],
u_tile
[
0
]);
// prefetch
d_thread_type
d_tile
[
2
];
do_load_b1
(
d_tile
[
0
],
true_v
);
do_load_b1
(
d_tile
[
1
],
true_v
);
// second buffer
do_gemm_0
(
acc_g
,
acc_u
,
a_warp_windows_1
,
g_tile
[
1
],
u_tile
[
1
]);
// redice acc_g/u
constexpr
auto
acc_spans_0
=
decltype
(
acc_g
)
::
get_distributed_spans
();
sweep_tile_span
(
acc_spans_0
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
acc_spans_0
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
element_wise
::
Silu
{}(
acc_g
(
i_j_idx
),
acc_g
(
i_j_idx
));
acc_g
(
i_j_idx
)
*=
acc_u
(
i_j_idx
);
});
});
constexpr
auto
n_per_block_1
=
Problem
::
FusedMoeTileShape
::
kN_d
;
const
auto
y
=
[
&
]()
{
const
index_t
loops_1
=
(
dim_size
+
n_per_block_1
-
1
)
/
n_per_block_1
;
if
constexpr
(
std
::
is_same_v
<
YDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
YDataType
>
(
acc_g
);
else
return
cast_tile
<
YDataType
>
(
acc_g
);
}();
auto
a_smem_ptr
=
reinterpret_cast
<
ADataType
*>
(
smem_ptr
)
+
a_smem_offset
;
auto
acc_d
=
MakeCBlockTile_Gemm1
<
Problem
>
();
clear_tile
(
acc_d
);
// TODO: reshuffle? 32x32x8 mfma can avlid LDS reshuffle
index_t
i_1
==
0
;
while
(
i_1
<
(
loops_1
-
2
))
{
// first buffer
do_gemm_1
(
acc_d
,
y
,
d_tile
[
0
]);
do_load_b1
(
d_tile
[
0
],
true_v
);
i_1
++
;
// second buffer
do_gemm_1
(
acc_d
,
y
,
d_tile
[
1
]);
do_load_b1
(
d_tile
[
1
],
true_v
);
i_1
++
;
}
auto
smem_0_window
=
make_tile_window
(
// first buffer
make_tensor_view
<
address_space_enum
::
lds
>
(
do_gemm_0
(
a_warp_windows_0
,
g_tile
[
0
],
g_tile
[
1
]);
smem_0
,
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>()),
i_0
++
;
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
async_load_tile
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})));
// second buffer
for
(
index_t
i_0
=
0
;
i_0
<
loops_0
;
i_0
++
)
{}
do_gemm_0
(
a_warp_windows_1
,
g_tile
[
1
],
g_tile
[
1
]);
i_0
++
;
}
}
template
<
typename
QDramBlockWindowTmp
,
template
<
typename
QDramBlockWindowTmp
,
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
View file @
199f7f71
This diff is collapsed.
Click to expand it.
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/kernel/fused_moe_kernel.hpp
View file @
199f7f71
This diff is collapsed.
Click to expand it.
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_tile_partitioner.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/kernel/fused_moe_tile_partitioner.hpp
View file @
199f7f71
...
@@ -45,4 +45,31 @@ struct FusedMoeTilePartitioner_PersistentSplitD
...
@@ -45,4 +45,31 @@ struct FusedMoeTilePartitioner_PersistentSplitD
}
}
};
};
template
<
typename
FusedMoeTileShape_
>
struct
FusedMoeTilePartitioner_Linear
{
using
Shape
=
ck_tile
::
remove_cvref_t
<
FusedMoeTileShape_
>
;
static
constexpr
const
char
*
name
=
"2d"
;
// expert x hidden
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*num_sorted_tiles*/
,
ck_tile
::
index_t
/*hidden_size*/
))
{
index_t
i_n
=
blockIdx
.
x
;
index_t
i_m
=
blockIdx
.
y
;
return
ck_tile
::
make_tuple
(
i_m
,
i_n
);
}
// persistent
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
max_tokens
,
index_t
hidden_size
)
{
// TODO: this may need tuning
index_t
grids
=
num_cu
*
blocks_per_cu
;
index_t
ms
=
ck_tile
::
integer_divide_ceil
(
max_tokens
,
Shape
::
kBlockM_0
);
index_t
ns
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
Shape
::
kBlockN_0
);
return
dim3
(
ns
,
ms
,
1
);
}
};
}
// namespace ck_tile
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp
0 → 100644
View file @
199f7f71
This diff is collapsed.
Click to expand it.
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2_policy.hpp
0 → 100644
View file @
199f7f71
This diff is collapsed.
Click to expand it.
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_problem.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/pipeline/fused_moe_pipeline_problem.hpp
View file @
199f7f71
...
@@ -33,17 +33,7 @@ struct FusedMoePipelineProblem
...
@@ -33,17 +33,7 @@ struct FusedMoePipelineProblem
static
constexpr
index_t
kBlockSize
=
FusedMoeTileShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
FusedMoeTileShape
::
NumWarps
*
get_warp_size
();
// attributes from traits
// static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
// static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
// static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
// static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// static constexpr auto BiasEnum = Traits::BiasEnum;
// static constexpr bool kStoreLSE = Traits::kStoreLSE;
// static constexpr bool kHasDropout = Traits::kHasDropout;
// static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
using
GateActivation
=
remove_cvref_t
<
typename
Traits
::
GateActivation_
>
;
//
using GateActivation = remove_cvref_t<typename Traits::GateActivation_>;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_tile_shape.hpp
→
example/ck_tile/05_moe/
include/ck_tile/ops/
fused_moe/pipeline/fused_moe_tile_shape.hpp
View file @
199f7f71
...
@@ -37,12 +37,11 @@ M_a | A | | | | | | | | |
...
@@ -37,12 +37,11 @@ M_a | A | | | | | | | | |
SILU x-----x +----------+
SILU x-----x +----------+
K_y = N_g = N_u dim
K_y = N_g = N_u dim
*/
*/
template
<
typename
BlockTile_
,
// sequence<M_a, N_g, N_
u
, K_a, N_d
template
<
typename
BlockTile_
,
// sequence<M_a, N_g, N_
sub0
, K_a, N_d
typename
Gemm0BlockWarps_
,
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
,
typename
Gemm1WarpTile_
>
bool
IsDLayoutRowMajor_
>
struct
FusedMoeTileShape
struct
FusedMoeTileShape
{
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
...
@@ -58,25 +57,60 @@ struct FusedMoeTileShape
...
@@ -58,25 +57,60 @@ struct FusedMoeTileShape
static
constexpr
index_t
kM_a
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kM_a
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kN_g
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN_g
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN_u
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kN_u
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kK_a
=
BlockTile
::
at
(
number
<
3
>
{});
// e.g. N_g = 256, n_sub_gu=128, then we split blockN of G/U into 2 parts to loopover
static
constexpr
index_t
kN_d
=
BlockTile
::
at
(
number
<
4
>
{});
// this can help B matrix direct-to-register using too much vgpr issue
static_assert
(
kN_g
==
kN_u
);
static
constexpr
index_t
kN_sub0
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kK_a
=
BlockTile
::
at
(
number
<
3
>
{});
static
constexpr
index_t
kN_d
=
BlockTile
::
at
(
number
<
4
>
{});
// static_assert(kN_g == kN_u);
static
constexpr
index_t
kK_y
=
kN_g
;
static
constexpr
index_t
kK_y
=
kN_g
;
static
constexpr
index_t
kM_0
=
kM_a
;
static
constexpr
index_t
kBlockNSub_0
=
kN_sub0
;
// allow partial
static
constexpr
index_t
kN_0
=
kN_g
;
// note N will x2
static
constexpr
index_t
kBlockM_0
=
kM_a
;
static
constexpr
index_t
kK_0
=
kK_a
;
static
constexpr
index_t
kBlockN_0
=
kN_g
;
// note N will x2 in real pipeline for gemm-0
static
constexpr
index_t
kBlockK_0
=
kK_a
;
static
constexpr
index_t
kWarpM_0
=
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kWarpN_0
=
Gemm0WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kWarpK_0
=
Gemm0WarpTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockWarpsM_0
=
Gemm0BlockWarps
::
at
(
numner
<
0
>
{});
static
constexpr
index_t
kBlockWarpsN_0
=
Gemm0BlockWarps
::
at
(
numner
<
1
>
{});
static
constexpr
index_t
kBlockWarpsK_0
=
Gemm0BlockWarps
::
at
(
numner
<
2
>
{});
static
constexpr
index_t
kSubBlockM_0
=
kWarpM_0
*
kBlockWarpsM_0
;
static
constexpr
index_t
kSubBlockN_0
=
kWarpN_0
*
kBlockWarpsN_0
;
static
constexpr
index_t
kSubBlockK_0
=
kWarpK_0
*
kBlockWarpsK_0
;
static_assert
(
kBlockM_0
%
kSubBlockM_0
==
0
);
static_assert
(
kBlockN_0
%
kSubBlockN_0
==
0
);
static_assert
(
kBlockK_0
%
kSubBlockK_0
==
0
);
static
constexpr
index_t
kWarpRepeatM_0
=
kBlockM_0
/
kSubBlockM_0
;
static
constexpr
index_t
kWarpRepeatN_0
=
kBlockN_0
/
kSubBlockN_0
;
static
constexpr
index_t
kWarpRepeatK_0
=
kBlockK_0
/
kSubBlockK_0
;
static
constexpr
index_t
kM_1
=
kM_0
;
static
constexpr
index_t
kBlockKSub_1
=
kBlockNSub_0
;
static
constexpr
index_t
kN_1
=
kN_d
;
static
constexpr
index_t
kBlockM_1
=
kM_a
;
static
constexpr
index_t
kK_1
=
kN_g
;
static
constexpr
index_t
kBlockN_1
=
kN_d
;
static
constexpr
index_t
kBlockK_1
=
kN_g
;
static
constexpr
index_t
kWarpM_1
=
Gemm1WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kWarpN_1
=
Gemm1WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kWarpK_1
=
Gemm1WarpTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kBlockWarpsM_1
=
Gemm1BlockWarps
::
at
(
numner
<
0
>
{});
static
constexpr
index_t
kBlockWarpsN_1
=
Gemm1BlockWarps
::
at
(
numner
<
1
>
{});
static
constexpr
index_t
kBlockWarpsK_1
=
Gemm1BlockWarps
::
at
(
numner
<
2
>
{});
static
constexpr
index_t
kSubBlockM_1
=
kWarpM_1
*
kBlockWarpsM_1
;
static
constexpr
index_t
kSubBlockN_1
=
kWarpN_1
*
kBlockWarpsN_1
;
static
constexpr
index_t
kSubBlockK_1
=
kWarpK_1
*
kBlockWarpsK_1
;
static_assert
(
kBlockM_1
%
kSubBlockM_1
==
0
);
static_assert
(
kBlockN_1
%
kSubBlockN_1
==
0
);
static_assert
(
kBlockK_1
%
kSubBlockK_1
==
0
);
static
constexpr
index_t
kWarpRepeatM_1
=
kBlockM_1
/
kSubBlockM_1
;
static
constexpr
index_t
kWarpRepeatN_1
=
kBlockN_1
/
kSubBlockN_1
;
static
constexpr
index_t
kWarpRepeatK_1
=
kBlockK_1
/
kSubBlockK_1
;
// d, rowmajor : hidden*dim, colmajor : dim*hidden (vLLM use this layout)
// d, rowmajor : hidden*dim, colmajor : dim*hidden (vLLM use this layout)
static
constexpr
bool
IsDLayoutRowMajor
=
IsDLayoutRowMajor_
;
//
static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
using
DLayout
=
std
::
conditional_t
<
IsDLayoutRowMajor
,
//
using DLayout = std::conditional_t<IsDLayoutRowMajor,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
,
//
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
//
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
namespace
ck_tile
{
template
<
bool
DownPreShuffled_
=
false
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
,
index_t
OAtomic_
=
0
,
FusedMoeWeightPermuteEnum
WeightPermute_
=
FusedMoeWeightPermuteEnum
::
permute_b_nr_kr_kw_nw_kv
>
struct
FusedMoeTraits
{
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
FusedMoeWeightPermuteEnum
WeightPermute
=
WeightPermute_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
// 0-pack fp16/bf16 atomic, 1-fp32 atomic
};
}
// namespace ck_tile
example/ck_tile/05_moe/include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp
0 → 100644
View file @
199f7f71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {

// Memory permutation applied to the expert weight tensor.
enum class FusedMoeWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
    permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
    // alias: same layout viewed as a flattened-wave permutation
    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
    no_permute = 999,
};

} // namespace ck_tile
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
199f7f71
...
@@ -616,11 +616,51 @@ struct buffer_store_if<1>
...
@@ -616,11 +616,51 @@ struct buffer_store_if<1>
}
}
};
};
// Block until at most `cnt` vector-memory loads are still outstanding.
// NOTE(review): the "n" asm constraint needs an immediate, so `cnt` must
// constant-fold at each call site despite being a runtime parameter.
CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Compile-time-count overload: waits until at most `cnt` vector-memory
// loads are still in flight.
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence_raw(number<cnt>)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#if 0
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
// const index_t origin_cnt = 0x0f70;
__builtin_amdgcn_s_waitcnt(0x0f70 | cnt);
}
#endif
// Intrinsic-based load fence: wait until at most `cnt` vector-memory loads
// remain outstanding, leaving expcnt/lgkmcnt untouched (0x0f70 = "no wait"
// for those counters).
//
// s_waitcnt simm16 encoding: vmcnt[3:0] -> simm16[3:0],
//                            vmcnt[5:4] -> simm16[15:14]
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence(number<cnt>)
{
    static_assert(cnt < 64);
    constexpr index_t low = cnt & 0xf;
    // bugfix: bits [5:4] of the count must land in simm16 bits [15:14],
    // i.e. a left shift of 10. The previous `<< 14` placed them at bits
    // [19:18], past the 16-bit immediate, silently dropping any cnt > 15.
    constexpr index_t hi = (cnt & 0x30) << 10;
    constexpr index_t c  = 0x0f70 | low | hi;
    __builtin_amdgcn_s_waitcnt(c);
}
// Workgroup execution barrier (s_barrier).
CK_TILE_DEVICE void wave_barrier() { __builtin_amdgcn_s_barrier(); }
// Block until at most `cnt` async (direct-to-LDS) loads remain outstanding.
// `cnt` must constant-fold ("n" constraint requires an immediate).
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Compile-time-count async-load fence; forwards to the intrinsic-based
// buffer_load_fence so the wait participates in normal scheduling.
template <index_t cnt>
CK_TILE_DEVICE auto async_load_fence(number<cnt>)
{
    buffer_load_fence(number<cnt>{});
}
// clang-format off
// clang-format off
namespace
impl
{
namespace
impl
{
...
@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
...
@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
}
}
// clang-format on
// clang-format on
// Load fence that additionally threads `o...` (typically thread buffers)
// through a dummy dependency, so the compiler cannot hoist uses of those
// buffers above the wait.
template <typename... T>
CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0, T&... o)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
    impl::insert_dummy_dep(o...);
}
// Block until at most `cnt` vector-memory stores remain outstanding.
CK_TILE_DEVICE void buffer_store_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
...
@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
...
@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int
soffset
,
// dst_wave_addr_offset
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
// Direct loads from global to LDS (maps to the raw.buffer.load.lds intrinsic).
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                index_t size,
                                index_t voffset,
                                index_t soffset,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template
<
bool
pre_nop
=
false
>
template
<
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
async_buffer_load_dword_v
(
void
*
smem
,
CK_TILE_DEVICE
void
async_buffer_load_dword_v
(
void
*
smem
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
...
@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
...
@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
:
"memory"
);
:
"memory"
);
}
}
#if 0
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
}
#endif
// memory coherency bit for buffer store/load instruction
// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
// check ISA manual for each GFX target
...
@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
...
@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
// Asynchronously (direct-to-LDS) load one dword per lane from global memory.
//
// smem - LDS destination pointer
// flag - per-lane validity; with oob_conditional_check enabled, invalid lanes
//        redirect their offset to buffer-resource word [2] so the hardware
//        drops the load (assumes word [2] holds the record count — the usual
//        buffer-resource layout; confirm against resource construction).
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                          int32x4_t src_wave_buffer_resource,
                                          index_t src_thread_addr_offset,
                                          index_t src_wave_addr_offset,
                                          index_t src_immediate_addr_offset = 0,
                                          index_t flag                      = 0,
                                          bool_constant<oob_conditional_check> = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

    if constexpr(oob_conditional_check)
    {
        // bugfix: the original initializer read `v_offset` from itself
        // (`flag ? v_offset : ...`), leaving valid lanes with an
        // indeterminate offset. Valid lanes must use the thread offset.
        const index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        // reinterpret_cast<CK_TILE_LDS_ADDR
                                        // uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
                                        sizeof(uint32_t),
                                        v_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else
    {
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        // reinterpret_cast<CK_TILE_LDS_ADDR
                                        // uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
                                        sizeof(uint32_t),
                                        src_thread_addr_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
}
template
<
index_t
N
,
template
<
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
...
@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
...
@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
}
}
// This version supports a buffer resource as input arg: direct-to-LDS load
// addressed by an element offset plus a per-lane validity flag.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    // element offset -> byte offset
    const index_t byte_offset = src_thread_element_offset * sizeof(T);

    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           byte_offset,
                                           0,
                                           0,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}
// buffer_store requires:
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// 2) p_dst_wave must be a wavewise pointer.
...
@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
...
@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
#endif
}
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_load_lds
(
int32x4_t
rsrc
,
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
,
index_t
size
,
index_t
voffset
,
index_t
soffset
,
index_t
offset
,
index_t
aux
)
__asm
(
"llvm.amdgcn.raw.buffer.load.lds"
);
template
<
typename
T
,
index_t
NumElemsPerThread
>
template
<
typename
T
,
index_t
NumElemsPerThread
>
CK_TILE_DEVICE
void
amd_direct_load_global_to_lds
(
const
T
*
global_base_ptr
,
CK_TILE_DEVICE
void
amd_direct_load_global_to_lds
(
const
T
*
global_base_ptr
,
const
index_t
global_offset
,
const
index_t
global_offset
,
...
...
include/ck_tile/core/config.hpp
View file @
199f7f71
...
@@ -43,6 +43,20 @@
...
@@ -43,6 +43,20 @@
#define CK_TILE_HOST_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
// bugfix: the guard was spelled `__HIPCC_` (single trailing underscore),
// which no compiler predefines — the attributes were silently disabled even
// under HIP device compilation. The HIP-predefined macro is `__HIPCC__`.
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#endif
...
...
include/ck_tile/core/tensor/buffer_view.hpp
View file @
199f7f71
...
@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global,
...
@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global,
dst
,
cached_buf_res_
,
i
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
dst
,
cached_buf_res_
,
i
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
}
// i is offset of T, not X. i should be aligned to X.
// Async (direct-to-LDS) element access: issues a buffer_load-to-LDS of one
// X-vector at element offset `i` of this global buffer into `smem`.
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                           typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
              bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
                                        index_t i,
                                        bool is_valid_element,
                                        bool_constant<oob_conditional_check> = {}) const
{
    // X is a vector of T
    constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
    constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

    static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                  "wrong! X should contain multiple T");

    constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

    amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
        smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
}
// i is offset of T, not X. i should be aligned to X
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
bool
pre_nop
=
false
,
bool
pre_nop
=
false
,
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @
199f7f71
...
@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect
// the dependency when using multiple smem windows (multiple buffering).
// Issues the async (direct-to-LDS) copy described by `tile_window` into the
// LDS tile `lds_tile`.
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile,
                const tile_window_with_static_distribution<BottomTensorView_,
                                                           WindowLengths_,
                                                           TileDistribution_,
                                                           NumCoord>& tile_window,
                bool_constant<oob_conditional_check> = {})
{
    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
typename
BottomTensorView_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
...
@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
CK_TILE_DEVICE
auto
async_load_fence
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
template
<
typename
WindowLengths
>
template
<
typename
WindowLengths
>
CK_TILE_DEVICE
auto
load_tile
(
const
null_tile_window
<
WindowLengths
>&
)
CK_TILE_DEVICE
auto
load_tile
(
const
null_tile_window
<
WindowLengths
>&
)
{
{
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
199f7f71
...
@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
...
@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
});
}
}
namespace detail {

// Trait: true when two static_distributed_tensor types carry the same data
// type and the same per-thread element count, differing only in distribution.
// (Public name keeps the historical "similiar" spelling for compatibility.)
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
    static constexpr bool value = false;
};

template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
                                      static_distributed_tensor<TypeY, DistY>>
{
    using Tx = static_distributed_tensor<TypeX, DistX>;
    using Ty = static_distributed_tensor<TypeY, DistY>;

    static constexpr bool value =
        std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
        Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};

// variable-template shorthand
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
    is_similiar_distributed_tensor<X, Y>::value;

} // namespace detail
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/tensor_view.hpp
View file @
199f7f71
...
@@ -104,6 +104,23 @@ struct tensor_view
...
@@ -104,6 +104,23 @@ struct tensor_view
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
// Async (direct-to-LDS) counterpart of get_vectorized_elements: instead of
// returning a value, issues a buffer_load-to-LDS into `smem` for the element
// addressed by `coord`, validity-checked via the coordinate guard.
// NOTE(review): a caller elsewhere passes an extra bool_constant argument —
// verify the intended overload set against the call sites.
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                              const TensorCoord& coord) const
{
    return buf_.template async_get<X>(
        smem,
        coord.get_offset(),
        coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
        bool_constant<oob_conditional_check>{});
}
template
<
typename
X
,
template
<
typename
X
,
bool
pre_nop
=
false
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
typename
std
::
enable_if
<
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
199f7f71
...
@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution
...
@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution
});
});
}
}
// Issue all async (direct-to-LDS) loads for this window into `lds_tile`.
// The LDS tile is a 3-d view: issues * warps * lanes.
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                               bool_constant<oob_conditional_check> = {}) const
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;

    // issues * warps * lanes
    static_assert(LdsTileWindow::get_num_of_dimension() == 3);

    // TODO: hard coded
    // TODO: an LDS offset is not good for the intrinsic-based implementation
    // (compiler can't figure out the dependency), hence avoid an offset-based
    // solution. size_per_buf should be zero (how to check?)
    constexpr index_t size_per_buf =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<0>{}, number<0>{}));

    // stride between consecutive warps in the LDS tile
    constexpr index_t size_per_wave =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<1>{}, number<0>{})) -
        size_per_buf;

    // stride between consecutive issues in the LDS tile
    constexpr index_t size_per_issue =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<1>{}, number<0>{}, number<0>{})) -
        size_per_buf;

    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

    using Traits   = load_store_traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    // TODO: we force CK_TILE_LDS_ADDR
    CK_TILE_LDS_ADDR LdsDataType* smem =
        lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // TODO: use structured binding (to be captured later) if compiled in C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // read from bottom tensor straight into LDS
            get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});

            // move thread coordinate (skip after the final access)
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys =
                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

                // Note we manually increase the per-issue offset
                smem += size_per_issue;
            }
        });
    });
}
template
<
bool
oob_conditional_check
=
true
>
template
<
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
...
...
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
199f7f71
...
@@ -39,7 +39,7 @@ struct Default2DEpilogue
...
@@ -39,7 +39,7 @@ struct Default2DEpilogue
if
constexpr
(
kPadM
||
kPadN
)
if
constexpr
(
kPadM
||
kPadN
)
{
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
buffer_store_fence
_raw
();
}
}
else
else
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
View file @
199f7f71
...
@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
store_tile
(
lse_acc_dram_window_tmp
,
store_tile
(
lse_acc_dram_window_tmp
,
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
tile_elementwise_in
(
lse_acc_element_func
,
lse_acc
));
}
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence
_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
// Note: here occ are all cleard, return it
return
o_acc
;
return
o_acc
;
...
@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...
@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
...
@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if
constexpr
(
k0_loops
<=
2
)
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
async_load_fence
_raw
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment