Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4412a07d
"include/vscode:/vscode.git/clone" did not exist on "ceebf3065329ece38bfd1f03d1e343f34c09ef71"
Commit
4412a07d
authored
Sep 01, 2024
by
carlushuang
Browse files
remove extra files
parent
54d3e2f1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
1667 deletions
+0
-1667
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
...tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
+0
-15
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
.../ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
+0
-1061
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
...e/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
+0
-591
No files found.
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_permute_enum.hpp
deleted
100644 → 0
View file @
54d3e2f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {

/// Identifiers for the weight permutation styles known to the fused-MoE example.
/// The trailing digit lists are the axis orders recorded by the original author.
enum class FusedMoeWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6

    /// axis order 0,1,3,4,2,5
    permute_b_nr_kr_kw_nw_kv = 2,

    /// alias of permute_b_nr_kr_kw_nw_kv (same underlying value)
    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,

    /// by name: no permutation applied
    no_permute = 999,
};

} // namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
deleted
100644 → 0
View file @
54d3e2f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
>
struct
FusedMoePipeline
{
// Template arguments with cv-qualifiers and references stripped.
using Problem = remove_cvref_t<Problem_>;
using Policy  = remove_cvref_t<Policy_>;

// Element types published by Problem.
// NOTE(review): A/G/U/D/O presumably stand for activation / gate / up / down / output,
// matching the operator() argument names below -- confirm against the Problem definition.
using ADataType     = remove_cvref_t<typename Problem::ADataType>;
using GDataType     = remove_cvref_t<typename Problem::GDataType>;
using UDataType     = remove_cvref_t<typename Problem::UDataType>;
using DDataType     = remove_cvref_t<typename Problem::DDataType>;
using ODataType     = remove_cvref_t<typename Problem::ODataType>;
using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
using ScaleDataType = remove_cvref_t<typename Problem::ScaleDataType>;

// Compile-time tile geometry; every k*_0 / k*_1 constant of this struct is read from it.
using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
// Block size, forwarded from Problem.
static constexpr index_t kBlockSize = Problem::kBlockSize;

// ---- tile geometry with suffix _0 (consumed by the first GEMM in operator()) ----
// per-block extents
static constexpr index_t kBlockM_0 = FusedMoeTileShape::kBlockM_0;
static constexpr index_t kBlockN_0 = FusedMoeTileShape::kBlockN_0;
static constexpr index_t kBlockK_0 = FusedMoeTileShape::kBlockK_0;
// per-warp extents
static constexpr index_t kWarpM_0 = FusedMoeTileShape::kWarpM_0;
static constexpr index_t kWarpN_0 = FusedMoeTileShape::kWarpN_0;
static constexpr index_t kWarpK_0 = FusedMoeTileShape::kWarpK_0;
// warps per block along each dimension
static constexpr index_t kBlockWarpsM_0 = FusedMoeTileShape::kBlockWarpsM_0;
static constexpr index_t kBlockWarpsN_0 = FusedMoeTileShape::kBlockWarpsN_0;
static constexpr index_t kBlockWarpsK_0 = FusedMoeTileShape::kBlockWarpsK_0;
// sub-block extents (used as the slice stride of the accumulators)
static constexpr index_t kSubBlockM_0 = FusedMoeTileShape::kSubBlockM_0;
static constexpr index_t kSubBlockN_0 = FusedMoeTileShape::kSubBlockN_0;
static constexpr index_t kSubBlockK_0 = FusedMoeTileShape::kSubBlockK_0;
// per-warp repeat counts (trip counts of the static_for loops)
static constexpr index_t kWarpRepeatM_0 = FusedMoeTileShape::kWarpRepeatM_0;
static constexpr index_t kWarpRepeatN_0 = FusedMoeTileShape::kWarpRepeatN_0;
static constexpr index_t kWarpRepeatK_0 = FusedMoeTileShape::kWarpRepeatK_0;

// ---- tile geometry with suffix _1 (consumed by the second GEMM in operator()) ----
// per-block extents
static constexpr index_t kBlockM_1 = FusedMoeTileShape::kBlockM_1;
static constexpr index_t kBlockN_1 = FusedMoeTileShape::kBlockN_1;
static constexpr index_t kBlockK_1 = FusedMoeTileShape::kBlockK_1;
// per-warp extents
static constexpr index_t kWarpM_1 = FusedMoeTileShape::kWarpM_1;
static constexpr index_t kWarpN_1 = FusedMoeTileShape::kWarpN_1;
static constexpr index_t kWarpK_1 = FusedMoeTileShape::kWarpK_1;
// warps per block along each dimension
static constexpr index_t kBlockWarpsM_1 = FusedMoeTileShape::kBlockWarpsM_1;
static constexpr index_t kBlockWarpsN_1 = FusedMoeTileShape::kBlockWarpsN_1;
static constexpr index_t kBlockWarpsK_1 = FusedMoeTileShape::kBlockWarpsK_1;
// sub-block extents
static constexpr index_t kSubBlockM_1 = FusedMoeTileShape::kSubBlockM_1;
static constexpr index_t kSubBlockN_1 = FusedMoeTileShape::kSubBlockN_1;
static constexpr index_t kSubBlockK_1 = FusedMoeTileShape::kSubBlockK_1;
// per-warp repeat counts
static constexpr index_t kWarpRepeatM_1 = FusedMoeTileShape::kWarpRepeatM_1;
static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;
using
MBlockType
=
decltype
(
GetMatrixCoreSwizzledBlockTIle_0
<
Problem
>
());
static
constexpr
index_t
kBlockNr_0
=
MBlockType
{}
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kBlockKr_0
=
MBlockType
{}
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kBlockWaveFlatten
=
MBlockType
{}
::
at
(
number
<
2
>
{});
// Blocks per CU: Problem::kBlockPerCu is honoured unless it equals -1, in which case the
// pipeline falls back to 2 (original author's note: "minimize occupancy").
static constexpr index_t kBlockPerCu =
    (Problem::kBlockPerCu != -1) ? Problem::kBlockPerCu : 2;
static
constexpr
const
char
*
name
=
"qr_async"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
/// Shared-memory size required by this pipeline.
/// Pure forwarder: the value is whatever Policy::GetSmemSize<Problem>() reports.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
// Thread-offset along row/col for the A tile: the result of calculate_index() on the
// distribution built by Policy::MakeGlobalTileDistribution_A<Problem>().
CK_TILE_HOST_DEVICE static auto GetAIndex()
{
    constexpr auto dstr = Policy::template MakeGlobalTileDistribution_A<Problem>();
    return dstr.calculate_index();
}
// Thread-offset along row/col for the O tile: the result of calculate_index() on the
// distribution built by Policy::MakeOGlobalTileDistribution<Problem>().
CK_TILE_HOST_DEVICE static auto GetOIndex()
{
    constexpr auto dstr = Policy::template MakeOGlobalTileDistribution<Problem>();
    return dstr.calculate_index();
}
/// Fused-MoE block pipeline: two chained GEMMs with an element-wise stage in between.
///
/// Stage 0 accumulates acc_g += A*G and acc_u += A*U, with A read back from LDS and G/U held
/// in registers. The element-wise stage then computes acc_g = Silu(acc_g) * acc_u and casts
/// the result to YDataType. Stage 1 accumulates acc_d += y*D. Both stages alternate between
/// two buffers ("first buffer" / "second buffer").
///
/// @param a_gtile_window_tmp  source of the A tiles (view, lengths and origin are re-used)
/// @param g_gtile_window_tmp  source of the G tiles
/// @param u_gtile_window_tmp  source of the U tiles
/// @param d_gtile_window_tmp  source of the D tiles
/// @param o_gtile_window_tmp  destination window for O (re-windowed below, not yet written)
/// @param scale               currently unused
/// @param smem_0, smem_1      the two LDS buffers that A is staged through
/// @param dim_size            drives both loop counts (see NOTE on loops_1)
/// @param hidden_size         currently unused
///
/// NOTE(review): this function is work in progress. It has a deduced return type but no
/// return statement, never stores acc_d to o_gtile_window, and relies on names that are not
/// declared in the visible part of the file (YDataType, NPerBlockPerIter, KPerBlockPerIter,
/// MakeCBlockTile_Gemm0/1). Those are left untouched and flagged inline.
template <typename AGlobalTensorView,
          typename GGlobalTileWindow,
          typename UGlobalTileWindow,
          typename DGlobalTileWindow,
          typename OGlobalTensorView>
CK_TILE_DEVICE auto operator()(const AGlobalTensorView& a_gtile_window_tmp,
                               const GGlobalTileWindow& g_gtile_window_tmp,
                               const UGlobalTileWindow& u_gtile_window_tmp,
                               const DGlobalTileWindow& d_gtile_window_tmp,
                               OGlobalTensorView& o_gtile_window_tmp,
                               // const void * sorted_weight_ptr,
                               ScaleDataType scale,
                               CK_TILE_LDS_ADDR void* smem_0,
                               CK_TILE_LDS_ADDR void* smem_1,
                               index_t dim_size,
                               index_t hidden_size)
{
    constexpr auto gemm_0 = Policy::template GetGemm0<Problem>();
    constexpr auto gemm_1 = Policy::template GetGemm1<Problem>();

    // Re-create every global window with the per-thread distribution chosen by the policy.
    auto a_gtile_window =
        make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
                         a_gtile_window_tmp.get_window_lengths(),
                         a_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeGlobalTileDistribution_A<Problem>());

    auto g_gtile_window =
        make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
                         g_gtile_window_tmp.get_window_lengths(),
                         g_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeGlobalTileDistribution_G<Problem>());

    auto u_gtile_window =
        make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
                         u_gtile_window_tmp.get_window_lengths(),
                         u_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeGlobalTileDistribution_U<Problem>());

    auto d_gtile_window =
        make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
                         d_gtile_window_tmp.get_window_lengths(),
                         d_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeGlobalTileDistribution_D<Problem>());

    // NOTE(review): built but never stored to in the visible code.
    auto o_gtile_window =
        make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
                         o_gtile_window_tmp.get_window_lengths(),
                         o_gtile_window_tmp.get_window_origin(),
                         Policy::template MakeOGlobalTileDistribution<Problem>());

    // Per-thread register tile types produced by loading each weight window.
    using g_thread_type = decltype(load_tile(g_gtile_window));
    using u_thread_type = decltype(load_tile(u_gtile_window));
    using d_thread_type = decltype(load_tile(d_gtile_window));

    // Ceil-divided iteration counts for the two GEMM stages.
    const index_t loops_0 = (dim_size + kBlockK_0 - 1) / kBlockK_0;
    // NOTE(review): also derived from dim_size while hidden_size is unused -- confirm.
    const index_t loops_1 = (dim_size + kBlockN_1 - 1) / kBlockN_1;

    // auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;

    // LDS store windows for A (3-D: issues_warps_lanes), one per buffer.
    auto a_sst_0 = make_tile_window(
        make_tensor_view<address_space_enum::lds>(
            smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
        Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
        {0, 0, 0});

    auto a_sst_1 = make_tile_window(
        make_tensor_view<address_space_enum::lds>(
            smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
        Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
        {0, 0, 0});

    // LDS load windows for A (2-D: m*k), one per buffer, over the same memory.
    auto a_sld_0 = make_tile_window(
        make_tensor_view<address_space_enum::lds>(
            smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
        Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
        {0, 0});

    auto a_sld_1 = make_tile_window(
        make_tensor_view<address_space_enum::lds>(
            smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
        Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
        {0, 0});

    // Two register tiles per weight, used alternately.
    g_thread_type g_tile[2];
    // Fix: u_tile was used throughout the pipeline but never declared.
    u_thread_type u_tile[2];

    // Fix: a type alias cannot name a call expression; take the type of the returned object.
    using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
    using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());

    auto warp_gemm_0 = WarpGemm0{};
    auto warp_gemm_1 = WarpGemm1{};

    // TODO: N first, M next
    const index_t i_mwarp_0 = get_warp_id() / kBlockWarpsN_0;

    // Builds a [kWarpRepeatM_0][kWarpRepeatK_0] table of warp-level A windows over one LDS
    // load window so they do not have to be re-created inside the GEMM loops.
    auto make_a_warp_windows = [&](auto a_sld_) {
        // window of this warp at repeat (0, 0)
        auto warp_window = make_tile_window(
            a_sld_.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm0::kM>{}, number<WarpGemm0::kK>{}),
            a_sld_.get_window_origin() + multi_index<2>{i_mwarp_0 * WarpGemm0::kM, 0},
            make_static_tile_distribution(typename WarpGemm0::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(warp_window), kWarpRepeatK_0>,
            kWarpRepeatM_0>
            ws;

        // pre-cache the warp windows
        static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m_iter) {
            static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k_iter) {
                ws(i_m_iter)(i_k_iter) = warp_window;
                // NOTE(review): NPerBlockPerIter / KPerBlockPerIter are not declared in the
                // visible code, and the first offset is along M -- confirm intended strides.
                move_tile_window(
                    ws(i_m_iter)(i_k_iter),
                    {i_m_iter * NPerBlockPerIter, i_k_iter * KPerBlockPerIter});
            });
        });
        return ws;
    };

    auto a_warp_windows_0 = make_a_warp_windows(a_sld_0);
    auto a_warp_windows_1 = make_a_warp_windows(a_sld_1);

    constexpr auto true_v  = bool_constant<true>{};
    constexpr auto false_v = bool_constant<false>{};

    // Async-copies the current A tile into the given LDS store window and, when move_ is
    // true, advances the A window by kBlockK_0 along its second dimension.
    auto do_load_a0 = [&](auto& a_store_, auto move_) {
        async_load_tile(a_store_, a_gtile_window);
        if constexpr(move_)
            move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}});
    };

    // Loads the current G and U tiles into registers and, when move_ is true, advances
    // both windows by kBlockKr_0 along their second dimension.
    auto do_load_b0 = [&](auto& g_tile_, auto& u_tile_, auto move_) {
        g_tile_ = load_tile(g_gtile_window);
        u_tile_ = load_tile(u_gtile_window);
        if constexpr(move_)
        {
            move_tile_window(g_gtile_window,
                             {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
            move_tile_window(u_gtile_window,
                             {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
        }
    };

    // Same as do_load_b0, for the D tile.
    // NOTE(review): steps by kBlockKr_0 (a gemm-0 constant) -- confirm for gemm 1.
    auto do_load_b1 = [&](auto& d_tile_, auto move_) {
        d_tile_ = load_tile(d_gtile_window);
        if constexpr(move_)
        {
            move_tile_window(d_gtile_window,
                             {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
        }
    };

    // using AWarpTensor = typename decltype(warp_gemm_0)::AWarpTensor{};
    // using CWarpTensor =
    // NOTE(review): MakeCBlockTile_Gemm0 is not declared in the visible code.
    auto acc_g = MakeCBlockTile_Gemm0<Problem>();
    auto acc_u = MakeCBlockTile_Gemm0<Problem>();

    // One step of gemm 0: for each (k, m) repeat load the A warp tile from LDS once, then
    // for each n repeat update the matching slices of acc_g and acc_u.
    auto do_gemm_0 = [&](auto& acc_g_,
                         auto& acc_u_,
                         auto& a_windows_,
                         auto& g_tile_,
                         auto& u_tile_) {
        // as_br (asmem, breg)
        static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k) {
            static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m) {
                const auto w_a = load_tile(a_windows_(i_m)(i_k));

                static_for<0, kWarpRepeatN_0, 1>{}([&](auto i_n) {
                    constexpr auto beg_acc =
                        sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
                    constexpr auto end_acc =
                        sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};

                    // 3d indexing for permuted g/u/d
                    // NOTE(review): the slice is indexed by i_m rather than i_k and ends at
                    // 0 in the last dimension -- looks unfinished, confirm.
                    constexpr auto beg_b =
                        sequence<i_m * kBlockWarpsM_0, i_n * kSubBlockN_0, 0>{};
                    constexpr auto end_b =
                        sequence<(i_m + 1) * kBlockWarpsM_0, (i_n + 1) * kSubBlockN_0, 0>{};

                    auto w_acc_g = get_slice_tile(acc_g_, beg_acc, end_acc);
                    auto w_acc_u = get_slice_tile(acc_u_, beg_acc, end_acc);
                    auto w_g     = get_slice_tile(g_tile_, beg_b, end_b);
                    auto w_u     = get_slice_tile(u_tile_, beg_b, end_b);

                    warp_gemm_0(w_acc_g, w_a, w_g);
                    warp_gemm_0(w_acc_u, w_a, w_u);

                    set_slice_tile(acc_g_, w_acc_g, beg_acc, end_acc);
                    set_slice_tile(acc_u_, w_acc_u, beg_acc, end_acc);
                });
            });
        });
    };

    // One step of gemm 1: A comes from a register tile (a_tile_), B from d_tile_.
    auto do_gemm_1 = [&](auto& acc_d_, auto& a_tile_, auto& d_tile_) {
        // ar_br (areg, breg)
        static_for<0, kWarpRepeatK_1, 1>{}([&](auto i_k) {
            static_for<0, kWarpRepeatM_1, 1>{}([&](auto i_m) {
                constexpr auto beg_a =
                    sequence<i_m * kSubBlockM_1, i_k * kSubBlockK_1>{};
                constexpr auto end_a =
                    sequence<(i_m + 1) * kSubBlockM_1, (i_k + 1) * kSubBlockK_1>{};
                const auto w_a = get_slice_tile(a_tile_, beg_a, end_a);

                static_for<0, kWarpRepeatN_1, 1>{}([&](auto i_n) {
                    // Fix: the gemm-1 accumulator is walked with gemm-1 repeat counts, so it
                    // must be sliced with gemm-1 sub-block sizes (the original used *_0).
                    constexpr auto beg_acc =
                        sequence<i_m * kSubBlockM_1, i_n * kSubBlockN_1>{};
                    constexpr auto end_acc =
                        sequence<(i_m + 1) * kSubBlockM_1, (i_n + 1) * kSubBlockN_1>{};

                    // 3d indexing for permuted g/u/d
                    // NOTE(review): still uses gemm-0 constants and the same unfinished
                    // indexing as do_gemm_0 -- left unchanged, confirm.
                    constexpr auto beg_b =
                        sequence<i_m * kBlockWarpsM_0, i_n * kSubBlockN_0, 0>{};
                    constexpr auto end_b =
                        sequence<(i_m + 1) * kBlockWarpsM_0, (i_n + 1) * kSubBlockN_0, 0>{};

                    auto w_acc_d = get_slice_tile(acc_d_, beg_acc, end_acc);
                    auto w_d     = get_slice_tile(d_tile_, beg_b, end_b);

                    warp_gemm_1(w_acc_d, w_a, w_d);

                    set_slice_tile(acc_d_, w_acc_d, beg_acc, end_acc);
                });
            });
        });
    };

    // ---- start of pipeline: prefetch both buffers ----
    do_load_a0(a_sst_0, true_v);
    do_load_b0(g_tile[0], u_tile[0], true_v);

    do_load_a0(a_sst_1, true_v);
    do_load_b0(g_tile[1], u_tile[1], true_v);

    clear_tile(acc_g);
    clear_tile(acc_u);

    // ---- gemm 0 main loop: consume one buffer, then refill it, alternating ----
    index_t i_0 = 0;
    while(i_0 < (loops_0 - 2))
    {
        // first buffer
        do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);
        do_load_a0(a_sst_0, true_v);
        do_load_b0(g_tile[0], u_tile[0], true_v);
        i_0++;

        // second buffer
        do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);
        do_load_a0(a_sst_1, true_v);
        do_load_b0(g_tile[1], u_tile[1], true_v);
        i_0++;
    }

    // ---- gemm 0 tail: drain both buffers, prefetching D for gemm 1 in between ----
    // first buffer
    do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);

    // prefetch
    d_thread_type d_tile[2];
    do_load_b1(d_tile[0], true_v);
    do_load_b1(d_tile[1], true_v);

    // second buffer
    do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);

    // ---- element-wise stage: acc_g = Silu(acc_g) * acc_u ----
    constexpr auto acc_spans_0 = decltype(acc_g)::get_distributed_spans();
    sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
        sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            // Silu is invoked with the same element as both arguments (in place).
            element_wise::Silu{}(acc_g(i_j_idx), acc_g(i_j_idx));
            acc_g(i_j_idx) *= acc_u(i_j_idx);
        });
    });

    // Cast the fused accumulator to the A-operand type of gemm 1.
    // NOTE(review): YDataType is not declared in the visible part of this struct.
    const auto y = [&]() {
        if constexpr(std::is_same_v<YDataType, fp16_t>)
            return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g);
        else
            return cast_tile<YDataType>(acc_g);
    }();

    // NOTE(review): MakeCBlockTile_Gemm1 is not declared in the visible code.
    auto acc_d = MakeCBlockTile_Gemm1<Problem>();
    clear_tile(acc_d);

    // TODO: reshuffle? 32x32x8 mfma can avoid LDS reshuffle
    // ---- gemm 1 main loop ----
    // Fix: the original read `index_t i_1 == 0;`, a comparison instead of an initialisation.
    index_t i_1 = 0;
    while(i_1 < (loops_1 - 2))
    {
        // first buffer
        do_gemm_1(acc_d, y, d_tile[0]);
        do_load_b1(d_tile[0], true_v);
        i_1++;

        // second buffer
        do_gemm_1(acc_d, y, d_tile[1]);
        do_load_b1(d_tile[1], true_v);
        i_1++;
    }

    // ---- gemm 1 tail: drain both D buffers ----
    // Fix: the original called the 5-parameter do_gemm_0 with three arguments here and
    // bumped i_0. This is a copy of the gemm-0 tail; it now mirrors that tail with do_gemm_1.
    // first buffer
    do_gemm_1(acc_d, y, d_tile[0]);

    // second buffer
    do_gemm_1(acc_d, y, d_tile[1]);
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
/*k_element_func*/
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
LdsSeq
=
Policy
::
template
GetLdsBufferSequence
<
Problem
>();
// K tile in LDS
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto
k_lds_load
=
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
)),
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
#else
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
#endif
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
q_dram_window
.
init_raw
();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto
q
=
decltype
(
load_tile
(
q_dram_window
)){};
// reg = copy(some_tensor_vew)
set_tile
(
q
,
number
<
0
>
{});
// use per-dword clear to avoid scratch
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
clear_tile
(
o_acc
);
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
{
if
(
num_total_loop
<=
0
)
{
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
set_tile
(
lse
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
buffer_load_fence_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
}
__builtin_amdgcn_sched_barrier
(
0
);
// make sure sched_barrier(0) for this check
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
k_dram_window
.
init_raw
();
constexpr
auto
k_oob_ck
=
bool_constant
<
true
>
{};
constexpr
auto
k_pre_np
=
[
&
]()
{
if
constexpr
(
kPadSeqLenK
&&
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHasDropout
)))
return
bool_constant
<
true
>
{};
else
return
bool_constant
<
false
>
{};
}();
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// prefetch K tile
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
1
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
if
constexpr
(
k0_loops
>
1
)
{
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
async_load_tile_raw
(
k_lds_store
(
number
<
LdsSeq
.
at
(
number
<
i_k0
+
1
>
{})
>
{}),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
i_k0
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
i_k0
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence_raw
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
auto
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
{
// tail
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load
[
number
<
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
>
{}]);
#else
get_slice_tile
(
k_lds_load
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{}))
*
kN0
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
-
1
>
{})
+
1
)
*
kN0
,
kK0
>
{}));
#endif
}
__builtin_amdgcn_sched_barrier
(
1
);
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x
+=
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#else
x
+=
log2e_v
<
SaccDataType
>
*
type_convert
<
SaccDataType
>
(
bias_element_func
(
y
));
#endif
},
s_acc
,
bias_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
sweep_tile_span
(
s_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
s_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
*=
scale_s
;
position_encoding
.
update
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
}
else
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout
([
&
scale_s
](
auto
&
x
)
{
x
=
x
*
scale_s
;
},
s_acc
);
#endif
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
q_origin
.
at
(
number
<
0
>
{}),
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
,
bool_constant
<
false
>
{});
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
get_tile_distribution
());
// Pcompute{j}
__builtin_amdgcn_sched_barrier
(
0x7F
);
// store & prefetch next v, after the max reduction
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store the prefetch
}
if
constexpr
(
k1_loops
>
1
)
{
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
// will have scratch if move this right after load_tile(v_dram)...
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
__builtin_amdgcn_sched_barrier
(
0
);
static
const
auto
get_validated_m
=
[](
SMPLComputeDataType
raw_m
)
{
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_m
==
-
numeric
<
SMPLComputeDataType
>::
infinity
()
?
type_convert
<
SMPLComputeDataType
>
(
0.
f
)
:
raw_m
;
}
else
{
return
raw_m
;
}
};
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
get_distributed_spans
();
sweep_tile_span
(
p_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
#endif
sweep_tile_span
(
p_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
p_compute
(
i_j_idx
)
=
exp2
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
p_compute
(
i_j_idx
)
=
exp2
(
scale_s
*
s
[
i_j_idx
]
-
row_max
);
}
#else
p_compute
(
i_j_idx
)
=
exp
(
s
[
i_j_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
,
bool_constant
<
false
>
{});
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
return
exp2
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
}
else
{
auto
row_max
=
scale_s
*
get_validated_m
(
m
[
i_idx
]);
return
exp2
(
scale_s
*
m_old
[
i_idx
]
-
row_max
);
}
}();
#else
const
auto
tmp
=
exp
(
m_old
[
i_idx
]
-
get_validated_m
(
m
[
i_idx
]));
#endif
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
if
constexpr
(
kHasDropout
)
{
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
const
auto
p
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
PDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
else
return
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
}();
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
!=
0
&&
i_k1
<
k1_loops
-
1
)
{
v_buf
=
load_tile
(
v_dram_window
,
bool_constant
<
false
>
{});
// load next v_buf
}
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
i_k1
+
1
>
{})
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
}
if
constexpr
(
i_k1
<
k1_loops
-
1
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
i_total_loops
++
;
if
(
i_total_loops
<
num_total_loop
)
{
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
k_dram_window
.
set_window_origin
(
k_dram_block_window
.
get_window_origin
());
if
constexpr
(
k1_loops
>=
2
&&
LdsSeq
.
at
(
number
<
0
>
{})
==
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
2
>
{}))
__builtin_amdgcn_s_barrier
();
async_load_tile_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})),
k_dram_window
,
k_oob_ck
,
k_pre_np
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{}))
*
kN1
,
0
>
{},
sequence
<
(
LdsSeq
.
at
(
number
<
k0_loops
+
k1_loops
-
1
>
{})
+
1
)
*
kN1
,
kK1
>
{}));
}
}
while
(
i_total_loops
<
num_total_loop
);
// store lse
if
constexpr
(
kStoreLSE
)
{
auto
lse
=
make_static_distributed_tensor
<
LSEDataType
>
(
m
.
get_tile_distribution
());
constexpr
auto
lse_spans
=
decltype
(
lse
)
::
get_distributed_spans
();
sweep_tile_span
(
lse_spans
[
number
<
0
>
{}],
[
&
,
m_
=
m
,
l_
=
l
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
else
{
lse
(
i_idx
)
=
m_
[
i_idx
]
*
scale_s
*
R_LOG2E
+
log
(
l_
[
i_idx
]);
}
#else
lse
(
i_idx
)
=
m_
[
i_idx
]
+
log
(
l_
[
i_idx
]);
#endif
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
else
return
1
/
l
[
i_idx
];
}();
sweep_tile_span
(
o_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
}
/// Convenience overload of the pipeline call operator.
///
/// Takes only the tile windows and scalar arguments and forwards them to the
/// full overload of operator(), supplying `identity{}` for every extra
/// argument slot that overload expects (presumably the per-tensor
/// element-wise functors -- confirm against the full overload's signature).
///
/// @return whatever the full overload returns.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,       // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
           LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
           FmhaMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           void* smem_ptr,
           DropoutType& dropout) const
{
    // Dispatch to the full overload on this same object.
    return (*this)(q_dram_block_window_tmp,
                   identity{},
                   k_dram_block_window_tmp,
                   identity{},
                   v_dram_block_window_tmp,
                   identity{},
                   bias_dram_block_window_tmp,
                   identity{},
                   randval_dram_block_window_tmp,
                   lse_dram_block_window_tmp,
                   identity{},
                   identity{},
                   identity{},
                   identity{},
                   mask,
                   position_encoding,
                   scale_s,
                   smem_ptr,
                   dropout);
}
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
deleted
100644 → 0
View file @
54d3e2f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
struct
FusedMoePipelinePolicy
{
/// Number of dwords moved by one async copy; currently fixed at 1.
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
    // TODO:
    constexpr index_t dwords_per_copy = 1;
    return dwords_per_copy;
}
/// Number of `Problem::ADataType` elements covered by one async copy
/// (4 bytes per dword times GetAsyncCopyDwords()).
///
/// Fails to compile if the copy size is not a whole multiple of the element size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
    // using async
    // Plain `constexpr` locals: a `static` local inside a constexpr function
    // is ill-formed before C++23.
    constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
    constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
/// Number of `Problem::GDataType` elements in one 16-byte copy.
///
/// Fails to compile if 16 is not a whole multiple of the element size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
    // Plain `constexpr` locals: a `static` local inside a constexpr function
    // is ill-formed before C++23. The former immediately-invoked lambda only
    // produced the constant 16.
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
/// Number of `Problem::UDataType` elements in one 16-byte copy.
///
/// Fails to compile if 16 is not a whole multiple of the element size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_U()
{
    // Plain `constexpr` locals: a `static` local inside a constexpr function
    // is ill-formed before C++23. The former immediately-invoked lambda only
    // produced the constant 16.
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
/// Number of `Problem::DDataType` elements in one 16-byte copy.
///
/// Fails to compile if 16 is not a whole multiple of the element size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
    // Plain `constexpr` locals: a `static` local inside a constexpr function
    // is ill-formed before C++23. The former immediately-invoked lambda only
    // produced the constant 16.
    constexpr index_t copy_bytes = 16;
    constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
    static_assert(copy_bytes % data_bytes == 0);
    return copy_bytes / data_bytes;
}
/// Number of `DataType_` elements that fit in 16 bytes.
///
/// @tparam DataType_ element type; cv/ref qualifiers are stripped before
///                   taking its size.
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
    // TODO: this is for 3d layout
    // The element type is the template parameter itself; there is no
    // `Problem` in scope here, so `Problem::DataType_` could not compile.
    return 16 / sizeof(remove_cvref_t<DataType_>);
}
/// GetSmemKPack() evaluated for `Problem::ADataType`.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
    using AType = typename Problem::ADataType;
    return GetSmemKPack<AType>();
}
/// Builds a static tile distribution for an MPerBlock x KPerBlock tile in
/// which every thread holds `Alignment` contiguous elements along K.
///
/// Two layouts are produced depending on how many threads are needed along K
/// (K_rem = KPerBlock / Alignment):
///  - K_rem larger than a wave: all lanes of a wave lie along K and several
///    waves cover K; the remaining waves are spread along M.
///  - otherwise: K_rem lanes of a wave lie along K, the remaining lanes and
///    all waves are spread along M.
///
/// @tparam MPerBlock tile extent along M
/// @tparam KPerBlock tile extent along K
/// @tparam NumWarps  number of waves in the block
/// @tparam Alignment elements per thread along K (vector length)
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
    constexpr index_t K_vec = Alignment; // the original was missing this ';'
    constexpr index_t K_rem = KPerBlock / K_vec;

    if constexpr(get_warp_size() < K_rem)
    {
        static_assert(K_rem % get_warp_size() == 0);
        constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
        constexpr index_t K_wav = K_rem / get_warp_size();
        static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
        constexpr index_t M_wav = NumWarps / K_wav;
        static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / M_wav;

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
                tuple<sequence<1, 2>, sequence<2>>,
                tuple<sequence<1, 0>, sequence<1>>,
                sequence<1, 2>,
                sequence<0, 2>>{});
    }
    else
    {
        constexpr index_t K_lan = K_rem;
        constexpr index_t M_lan = get_warp_size() / K_lan;
        constexpr index_t M_wav = NumWarps;
        static_assert(MPerBlock % (M_lan * M_wav) == 0,
                      "this tile size is too small please check");
        constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 0>>,
                sequence<1, 2>,
                sequence<0, 1>>{});
    }
}
// optimized version for async
/// Builds a static tile distribution for an MPerBlock x KPerBlock tile in
/// which every thread holds `Alignment` contiguous elements along K.
///
/// @tparam MPerBlock tile extent along M
/// @tparam KPerBlock tile extent along K
/// @tparam NumWarps  number of waves in the block
/// @tparam Alignment elements per thread along K (vector length)
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
    constexpr auto kLanesPerWave = get_warp_size();
    constexpr index_t kVecK      = Alignment;
    constexpr index_t kThreadsK  = KPerBlock / kVecK; // threads needed along K

    if constexpr(kThreadsK < kLanesPerWave)
    {
        // K fits inside a single wave; leftover lanes and all waves go along M.
        constexpr index_t kLanesK = kThreadsK;
        constexpr index_t kLanesM = kLanesPerWave / kLanesK;
        constexpr index_t kWavesM = NumWarps;
        static_assert(MPerBlock % (kLanesM * kWavesM) == 0,
                      "this tile size is too small please check");
        constexpr index_t kRepeatM = MPerBlock / (kLanesM * kWavesM);

        // NOTE: swapped for LDS load bank conflict free
        using Encoding =
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<kRepeatM, kWavesM, kLanesM>,
                                             sequence<kLanesK, kVecK>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>;
        return make_static_tile_distribution(Encoding{});
    }
    else
    {
        // K needs at least one full wave.
        static_assert(kThreadsK % kLanesPerWave == 0);
        constexpr index_t kLanesK = kLanesPerWave; // lane within same wave is along gemm-k
        constexpr index_t kWavesK = kThreadsK / kLanesPerWave;
        static_assert(kWavesK <= NumWarps, "do not support thread has repeat along K yet");
        constexpr index_t kWavesM = NumWarps / kWavesK;
        static_assert(MPerBlock % kWavesM == 0, "this tile size is too small please check");
        constexpr index_t kRepeatM = MPerBlock / kWavesM;

        // NOTE: no swap, but hard to avoid LDS bank conflict
        using Encoding =
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<kRepeatM, kWavesM>,
                                             sequence<kWavesK, kLanesK, kVecK>>,
                                       tuple<sequence<1, 2>, sequence<2>>,
                                       tuple<sequence<1, 0>, sequence<1>>,
                                       sequence<1, 2>,
                                       sequence<0, 2>>;
        return make_static_tile_distribution(Encoding{});
    }
}
/// Shape of the pre-shuffled weight block for gemm-0 as a 3-D
/// `sequence<Nr, Kr, Kw * Nw * Kv>`.
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
    if constexpr(Problem::Traits::PermuteStyle ==
                 FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        // GetWarpGemm0 returns an object, so the type has to be taken with decltype.
        // assume warpgemm0/1 are the same
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

        constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_0;

        constexpr index_t Kv = GetAlignment_G<Problem>();
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

        // K is split as Kr * Kw * Kv, so the block must divide evenly
        // (the original referred to undeclared names K1/K2 here).
        static_assert(KPerBlock % (Kv * Kw) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);

        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
/// Shape of the pre-shuffled weight block for gemm-1 as a 3-D
/// `sequence<Nr, Kr, Kw * Nw * Kv>`.
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
    if constexpr(Problem::Traits::PermuteStyle ==
                 FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        // GetWarpGemm1 returns an object, so the type has to be taken with decltype.
        // assume warpgemm0/1 are the same
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

        constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_1;
        constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_1;

        // NOTE(review): the gemm-1 distribution (MakeGlobalTileDistribution_D)
        // uses GetAlignment_D; GetAlignment_G is kept here as in the original --
        // confirm which one is intended.
        constexpr index_t Kv = GetAlignment_G<Problem>();
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

        // K is split as Kr * Kw * Kv, so the block must divide evenly
        // (the original referred to undeclared names K1/K2 here).
        static_assert(KPerBlock % (Kv * Kw) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);

        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
// Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channel while skip LDS
/// Builds the static tile distribution used to load a pre-shuffled weight
/// block laid out as (Nr, Kr, Kw * Nw * Kv).
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void).
///
/// @tparam NPerBlock       block extent along N
/// @tparam KPerBlock       block extent along K
/// @tparam WavesPerBlock_N waves along N
/// @tparam WavesPerBlock_K waves along K
/// @tparam WarpGemm        warp-gemm type supplying kAMLane / kABKLane / kABKPerLane
/// @tparam Alignment       elements per thread along K (Kv)
/// @tparam PermuteStyle    weight permutation scheme
template <index_t NPerBlock,
          index_t KPerBlock,
          index_t WavesPerBlock_N,
          index_t WavesPerBlock_K,
          typename WarpGemm,
          index_t Alignment,
          FusedMoeWeightPermuteEnum PermuteStyle =
              FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
    static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);

    if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
        constexpr index_t Kv = Alignment;
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

        // K is split as Kr * Kw * Kv, so the block must divide evenly
        // (the original referred to undeclared names K1/K2 here).
        static_assert(KPerBlock % (Kv * Kw) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);

        constexpr index_t Nr_p = WavesPerBlock_N;
        constexpr index_t Kr_p = WavesPerBlock_K;
        constexpr index_t Nr_y = Nr / Nr_p;
        constexpr index_t Kr_y = Kr / Kr_p;

        // clang-format off
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,                            // 0
                // major       1                     2                     3
                // minor       0     1               0     1               0   1   2
                tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,

                //            Nr_p, Kr_p      Kw Nw
                tuple<sequence<1, 2>, sequence<3, 3>>,
                tuple<sequence<1, 1>, sequence<0, 1>>,

                //       Nr_y Kr_y Kv
                sequence<1, 2, 3>,
                sequence<0, 0, 2>>{});
        // clang-format on
    }
}
/// Tile distribution for loading the A tile: the async simple MxK
/// distribution instantiated with the A-tile extents, the block's wave count
/// and the A alignment.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
    using Shape = typename Problem::FusedMoeTileShape;

    // NOTE(review): MakeLdsStoreDesc_A / MakeLdsLoadDesc_A read kBlockM_0 /
    // kBlockK_0 for the A tile while this function reads kM_a / kK_a --
    // confirm both pairs exist on the tile shape and agree.
    constexpr index_t m_extent  = Shape::kM_a;
    constexpr index_t k_extent  = Shape::kK_a;
    constexpr index_t num_waves = Shape::NumWarps;
    constexpr index_t vec_len   = GetAlignment_A<Problem>();

    return MakeGlobalTileDistribution_SimpleMxK_Async<m_extent, k_extent, num_waves, vec_len>();
}
/// Tile distribution for loading the G weight block (gemm-0 shape, G alignment).
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void). `NSplits` is accepted but not used
/// by this body.
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G(number<NSplits> = {})
{
    constexpr auto kPermuteStyle = Problem::Traits::PermuteStyle;
    if constexpr(kPermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;

        // GetWarpGemm0 returns an object; decltype is required to name its
        // type (the original passed the call expression as a type argument).
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

        constexpr index_t Alignment = GetAlignment_G<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              kPermuteStyle>();
    }
}
/// Tile distribution for loading the U weight block (gemm-0 shape, U alignment).
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void). `NSplits` is accepted but not used
/// by this body.
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U(number<NSplits> = {})
{
    constexpr auto kPermuteStyle = Problem::Traits::PermuteStyle;
    if constexpr(kPermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_0;
        constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;

        // GetWarpGemm0 returns an object; decltype is required to name its
        // type (the original passed the call expression as a type argument).
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

        constexpr index_t Alignment = GetAlignment_U<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              kPermuteStyle>();
    }
}
/// Tile distribution for loading the D weight block (gemm-1 shape, D alignment).
///
/// Only defined for `permute_b_nr_kr_kw_nw_kv`; for any other permute style
/// the function returns nothing (void).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
    constexpr auto kPermuteStyle = Problem::Traits::PermuteStyle;
    if constexpr(kPermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
    {
        constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_1;
        constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_1;
        constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_1;
        constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_1;

        // GetWarpGemm1 returns an object; decltype is required to name its
        // type (the original passed the call expression as a type argument).
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

        constexpr index_t Alignment = GetAlignment_D<Problem>();
        return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                              kKPerBlock,
                                                              WavesPerBlock_N,
                                                              WavesPerBlock_K,
                                                              WarpGemm,
                                                              Alignment,
                                                              kPermuteStyle>();
    }
}
/// Tensor descriptor describing how the async copy of the A tile lands in
/// LDS, returned as a 3-D (issues, warps, lanes) view.
///
/// Each warp's row of `warpSize * kVector` elements is followed by `kPad`
/// elements of padding. When K needs more waves than the block has, the body
/// is still a TODO and the function returns nothing (void).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
    // A async->LDS
    constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
    constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
    constexpr index_t warpSize   = ck_tile::get_warp_size();
    constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(kKPerBlock % kVector == 0);
    constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K

    // The original spelled the vector length `KVector` below, which is not
    // declared anywhere; the local is `kVector`.
    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        if constexpr(wavesPerK > NumWarps)
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = kMPerBlock / wavesPerM;

            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{}, // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<kVector>{}),  // k2
                make_tuple(number<NumWarps*(warpSize * kVector + kPad)>{},  // m0
                           number<wavesPerK*(warpSize * kVector + kPad)>{}, // m1
                           number<warpSize * kVector + kPad>{},             // k0
                           number<kVector>{},                               // k1
                           number<1>{}),                                    // k2
                number<kVector>{}, // lds store vector(actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_pass_through_transform(number<NumIssues>{}),
                    make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                    make_merge_transform(make_tuple(number<warpSize>{}, number<kVector>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = kMPerBlock / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<kVector>{}),   // k1
            make_tuple(number<NumWarps*(warpSize * kVector + kPad)>{}, // m0
                       number<kKPerBlock>{},                           // m1
                       number<warpSize * kVector + kPad>{},            // m2
                       number<kVector>{},                              // k0
                       number<1>{}),                                   // k1
            number<kVector>{}, // lds store vector(actually no explicit store)
            number<1>{});

        constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<kVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        return lds_block_desc_issues_warps_lanes;
    }
}
/// Tensor descriptor for the A tile as it sits in LDS, returned as a 2-D
/// (M, K) view.
///
/// When K needs more waves than the block has, the body is still a TODO and
/// the function returns nothing (void).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
    // A async->LDS
    // Note that, this descriptor is only to construct the layout inside LDS
    // in real Gemm pipeline, ds_read may not follow this pattern
    // (may follow that in tile_distribution)
    // below code is almost the same as SmemStore dist, with difference:
    //  1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
    //  2). return descriptor is in NxK 2d layout
    constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
    constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
    constexpr index_t warpSize   = ck_tile::get_warp_size();
    constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t kVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(kKPerBlock % kVector == 0);
    constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K

    // The original spelled the vector length `KVector` below, which is not
    // declared anywhere; the local is `kVector`.
    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        // Strict '>' so the wavesPerK == NumWarps case (wavesPerM == 1) gets a
        // descriptor, matching MakeLdsStoreDesc_A; with '>=' this function
        // returned void for a layout the store side supports.
        if constexpr(wavesPerK > NumWarps)
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = kMPerBlock / wavesPerM;

            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{}, // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<kVector>{}),  // k2
                make_tuple(number<NumWarps*(warpSize * kVector + kPad)>{},  // m0
                           number<wavesPerK*(warpSize * kVector + kPad)>{}, // m1
                           number<warpSize * kVector + kPad>{},             // k0
                           number<kVector>{},                               // k1
                           number<1>{}),                                    // k2
                number<KPack>{}, // lds load vector
                number<1>{});

            constexpr auto lds_desc_m_k = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
                    make_merge_transform(make_tuple(
                        number<wavesPerK>{}, number<warpSize>{}, number<kVector>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return lds_desc_m_k;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = kMPerBlock / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<kVector>{}),   // k1
            make_tuple(number<NumWarps*(warpSize * kVector + kPad)>{}, // m0
                       number<kKPerBlock>{},                           // m1
                       number<warpSize * kVector + kPad>{},            // m2
                       number<kVector>{},                              // k0
                       number<1>{}),                                   // k1
            number<KPack>{}, // lds load vector
            number<1>{});

        constexpr auto lds_desc_m_k = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(
                make_merge_transform(make_tuple(
                    number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                make_merge_transform(make_tuple(number<LanesPerK>{}, number<kVector>{}))),
            make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return lds_desc_m_k;
    }
}
/// Returns a default-constructed warp-gemm object for gemm-0, selected by
/// WarpGemmMfmaDispatcher from the A / G / accumulator data types and the
/// three entries of `Gemm0WarpTile`, with TransposeC enabled.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
    using Shape    = typename Problem::FusedMoeTileShape;
    using WarpTile = typename Shape::Gemm0WarpTile;

    using Selected = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                            typename Problem::GDataType,
                                            typename Problem::AccDataType,
                                            WarpTile::at(number<0>{}),
                                            WarpTile::at(number<1>{}),
                                            WarpTile::at(number<2>{}),
                                            true /*TransposeC*/>;
    return Selected{};
}
/// Returns a default-constructed warp-gemm object for gemm-1, selected by
/// WarpGemmMfmaDispatcher from the Y / D / accumulator data types and the
/// three entries of `Gemm1WarpTile`, with TransposeC enabled.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
    using Shape    = typename Problem::FusedMoeTileShape;
    using WarpTile = typename Shape::Gemm1WarpTile;

    using Selected = WarpGemmMfmaDispatcher<typename Problem::YDataType,
                                            typename Problem::DDataType,
                                            typename Problem::AccDataType,
                                            WarpTile::at(number<0>{}),
                                            WarpTile::at(number<1>{}),
                                            WarpTile::at(number<2>{}),
                                            true /*TransposeC*/>;
    return Selected{};
}
/// Creates the distributed block tensor that holds the gemm-0 result.
///
/// The block-level encoding (warp repeats and warps along M and N) is
/// embedded with the warp gemm's `CWarpDstrEncoding` to form the tile
/// distribution of the returned tensor.
///
/// NOTE(review): unlike the other members this one is a non-static const
/// member function; kept as in the original.
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0() const
{
    using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;

    constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_0;
    constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_0;
    constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_0;
    constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_0;

    using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN, BlockWarpsN>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

    // The original named an undeclared `CDataType` here. The element type used
    // is the accumulator type that GetWarpGemm0 hands to the warp gemm.
    // NOTE(review): confirm AccDataType is the intended C element type.
    auto c_block_tensor =
        make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
    return c_block_tensor;
}
/// Creates the distributed block tensor that holds the gemm-1 result.
///
/// The block-level encoding (warp repeats and warps along M and N) is
/// embedded with the warp gemm's `CWarpDstrEncoding` to form the tile
/// distribution of the returned tensor.
///
/// NOTE(review): unlike the other members this one is a non-static const
/// member function; kept as in the original.
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const
{
    using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;

    constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_1;
    constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_1;
    constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_1;
    constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_1;

    using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN, BlockWarpsN>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

    // The original named an undeclared `CDataType` here. The element type used
    // is the accumulator type that GetWarpGemm1 hands to the warp gemm.
    // NOTE(review): confirm AccDataType is the intended C element type.
    auto c_block_tensor =
        make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
    return c_block_tensor;
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment