Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
36da8a88
Commit
36da8a88
authored
Oct 15, 2024
by
Po Yen, Chen
Browse files
Async copy K tile
parent
54d2e0a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
76 deletions
+49
-76
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
+26
-17
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
+23
-59
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
View file @
36da8a88
...
...
@@ -60,8 +60,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
...
...
@@ -129,7 +128,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
k_element_func
,
[[
maybe_unused
]]
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
...
...
@@ -164,11 +163,19 @@ struct BlockFmhaPipelineQRKSVS2Wave
"wrong!"
);
// K tile in LDS
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>());
auto
k_lds_window_for_store
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0BlockLength
>
{}),
{
0
,
0
});
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
number
<
0
>
{})),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>(
number
<
0
>
{}).
get_lengths
(),
{
0
,
0
,
0
});
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeKLdsLoadBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -279,15 +286,16 @@ struct BlockFmhaPipelineQRKSVS2Wave
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
k_dram_window
.
init_raw
();
// this is necessary for async_load_tile_raw()
{
constexpr
auto
k_oob_ck
=
bool_constant
<
true
>
{};
constexpr
auto
k_pre_np
=
bool_constant
<
true
>
{};
async_load_tile_raw
(
k_lds_store
,
k_dram_window
,
k_oob_ck
,
k_pre_np
);
clear_tile
(
s_acc
);
store_tile
(
k_lds_window_for_store
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
)
);
async_load_fence
();
__builtin_amdgcn_sched_barrier
(
0
);
}
auto
k_lds_window_for_load
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -302,14 +310,15 @@ struct BlockFmhaPipelineQRKSVS2Wave
}
{
block_sync_lds
();
__builtin_amdgcn_s_barrier
();
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_for_load
);
move_tile_window
(
k_lds_window_for_load
,
{
0
,
kK0
});
get_slice_tile
(
k_lds_load
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kN0
,
(
i_k0
+
1
)
*
kK0
>
{}));
});
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
View file @
36da8a88
...
...
@@ -107,7 +107,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
// end copy from BlockFmhaPipelineQXCustomPolicy<true>
// start copy from BlockFmhaPipelineQXKSVSCustomPolicy
static
constexpr
bool
AsyncCopyK
=
fals
e
;
static
constexpr
bool
AsyncCopyK
=
tru
e
;
static
constexpr
bool
AsyncCopyV
=
false
;
// TODO: this not supported yet
static
constexpr
index_t
NumPrefetchK
=
1
;
...
...
@@ -255,9 +255,12 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
KPack
=
GetSmemKPackK
<
Problem
>
();
// this is for lds
constexpr
index_t
KVector
=
GetAlignmentK
<
Problem
>
();
// this is for global load
...
...
@@ -326,8 +329,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return
q_block_dstr
;
}
#if 0 // [POYENC] disabled since we are using
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
#if 0
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
...
...
@@ -352,6 +354,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc;
}
#endif
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
...
...
@@ -359,7 +362,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
...
...
@@ -407,61 +412,14 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return
k_lds_block_desc_issues_warps_lanes
;
}
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
#else
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
()
{
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
...
...
@@ -510,8 +468,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return
k_lds_block_desc
;
}
#endif
#if 0 // [POYENC] disabled since we are using
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
...
...
@@ -640,7 +599,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
// [POYENC] updated tile size
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NumWarps
=
Problem
::
BlockFmhaShape
::
NumWarps
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
...
...
@@ -868,6 +830,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
}
// end copy from BlockFmhaPipelineQXKSVSCustomPolicy
#if 0
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
...
...
@@ -893,6 +856,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc;
}
#endif
// 3d + padding
template
<
typename
Problem
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment