Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b75c9265
Commit
b75c9265
authored
Sep 24, 2024
by
Po Yen, Chen
Browse files
Load K once in 2wave pipeline
parent
ee44cf04
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
127 additions
and
54 deletions
+127
-54
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
+27
-48
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
+100
-6
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave.hpp
View file @
b75c9265
...
...
@@ -35,6 +35,8 @@ struct BlockFmhaPipelineQRKSVS2Wave
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
bool
kKLoadOnce
=
true
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
...
...
@@ -149,22 +151,23 @@ struct BlockFmhaPipelineQRKSVS2Wave
std
::
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0BlockLength
==
QDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK0BlockLength
==
KDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}]
&&
kM0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
KDataType
*>
(
smem_ptr
),
Policy
::
template
MakeKLdsStoreBlockDescriptor
<
Problem
>());
auto
k_lds_window_for_store
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0BlockLength
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -264,7 +267,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
do
{
...
...
@@ -278,11 +281,12 @@ struct BlockFmhaPipelineQRKSVS2Wave
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
clear_tile
(
s_acc
);
// initialize C
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
k_block_tile
=
load_tile
(
k_dram_window
);
clear_tile
(
s_acc
);
store_tile
(
k_lds_window_for_store
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
}
auto
k_lds_window_for_load
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -296,44 +300,19 @@ struct BlockFmhaPipelineQRKSVS2Wave
0
);
// prevent from messing up the order of global loads
}
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
k_lds_window_for_load
);
move_tile_window
(
k_lds_window_for_load
,
{
0
,
kK0
});
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_2wave_default_policy.hpp
View file @
b75c9265
...
...
@@ -9,11 +9,105 @@
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
/// Default policy for the "QR KS VS 2-wave" FMHA pipeline.
///
/// Inherits everything from BlockFmhaPipelineQXKSVSCustomPolicy (Q loaded once, no async
/// copy for K or V, one prefetch buffer each for K and V) and supplies the K-tile DRAM
/// distribution, the K-tile LDS store descriptor and the shared-memory size queries
/// defined below. All K-related helpers work on a tile of
/// kN0 x kK0BlockLength elements.
struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                          /* AsyncCopyK = */ false,
                                          /* AsyncCopyV = */ false,
                                          /* NumPrefetchK = */ 1,
                                          /* NumPrefetchV = */ 1>
{
    /// Builds the static tile distribution used for a kN0 x kK0BlockLength K tile.
    ///
    /// The N extent is factored into three levels and the K extent into two; the factors
    /// are derived from the element size, the warp size and the block size.
    /// NOTE(review): the divisions below truncate; the tile shape is presumably expected
    /// to divide evenly - confirm against the supported configurations.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
    {
        using KElem = remove_cvref_t<typename Problem::KDataType>;

        constexpr index_t block_size = Problem::kBlockSize;
        constexpr index_t n_extent   = Problem::BlockFmhaShape::kN0;
        constexpr index_t k_extent   = Problem::BlockFmhaShape::kK0BlockLength;

        // Innermost K factor: how many elements fit in 16 bytes.
        constexpr index_t k_vec = 16 / sizeof(KElem);
        // Outer K factor: number of 16-byte groups along K.
        constexpr index_t k_groups = k_extent / k_vec;
        // Innermost N factor: warp size divided by the number of K groups.
        constexpr index_t n_inner = get_warp_size() / k_groups;
        // Middle N factor: block size divided by warp size.
        constexpr index_t n_mid = block_size / get_warp_size();
        // Outermost N factor: whatever remains of the N extent.
        constexpr index_t n_outer = n_extent / (n_inner * n_mid);

        using Encoding = tile_distribution_encoding<
            sequence<1>,
            tuple<sequence<n_outer, n_mid, n_inner>, sequence<k_groups, k_vec>>,
            tuple<sequence<1>, sequence<1, 2>>,
            tuple<sequence<1>, sequence<2, 0>>,
            sequence<1, 2>,
            sequence<0, 1>>;

        return make_static_tile_distribution(Encoding{});
    }

    /// Builds the LDS descriptor through which the whole K tile is stored.
    ///
    /// The underlying layout is 3-D, (K / kKPack, N, kKPack), where the stride of the
    /// outermost dimension is (N + 1) * kKPack, i.e. one kKPack-sized slot larger than
    /// the N * kKPack elements it holds (presumably padding against LDS bank conflicts -
    /// TODO confirm). It is then exposed as a 2-D (N, K) view.
    // TODO: this is used for non async copy desc. unify in the future
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
    {
        constexpr index_t n_extent = Problem::BlockFmhaShape::kN0;
        constexpr index_t k_extent = Problem::BlockFmhaShape::kK0BlockLength;
        constexpr index_t k_pack   = GetSmemKPackK<Problem>();

        constexpr auto padded_lengths =
            make_tuple(number<k_extent / k_pack>{}, number<n_extent>{}, number<k_pack>{});
        constexpr auto padded_strides =
            make_tuple(number<(n_extent + 1) * k_pack>{}, number<k_pack>{}, number<1>{});

        // The trailing number<8>/number<1> arguments are forwarded unchanged.
        // NOTE(review): they look like vector-length/stride guarantees for the last
        // dimension - confirm against make_naive_tensor_descriptor.
        constexpr auto padded_desc = make_naive_tensor_descriptor(
            padded_lengths, padded_strides, number<8>{}, number<1>{});

        // Output dim 0 <- input dim 1 (N, passed through);
        // output dim 1 <- input dims 0 and 2 merged (K / kKPack and kKPack).
        return transform_tensor_descriptor(
            padded_desc,
            make_tuple(
                make_pass_through_transform(number<n_extent>{}),
                make_merge_transform(make_tuple(number<k_extent / k_pack>{}, number<k_pack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
    }

    /// Returns the element count of one LDS buffer large enough for either the K tile
    /// or the V tile (the larger of the two requirements).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
    {
        // this function assume K/V can share smem

        // K requirement: element space of the K store descriptor.
        constexpr index_t k_elements =
            MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size();

        // V requirement: computed from the kN1 x kK1 tile, padded by kKPack per row.
        using VElem = remove_cvref_t<typename Problem::VDataType>;

        constexpr index_t bank_count = 32; // TODO: need change based on arch
        // Elements per row of bank_count * 4 bytes.
        constexpr index_t row_elements = bank_count * 4 / sizeof(VElem);
        constexpr index_t k_pack       = GetSmemKPackK<Problem>();
        static_assert(row_elements % k_pack == 0);

        constexpr index_t n_per_row = row_elements / k_pack;
        constexpr index_t v_n       = Problem::BlockFmhaShape::kN1;
        constexpr index_t v_k       = Problem::BlockFmhaShape::kK1;
        static_assert(v_n % n_per_row == 0);
        static_assert(v_k % k_pack == 0);

        constexpr index_t v_elements =
            (v_k / k_pack) * (v_n / n_per_row) * (row_elements + k_pack);

        return max(k_elements, v_elements);
    }

    /// Returns the shared-memory size in bytes: the Q size reported by QXPolicy plus one
    /// K/V buffer per prefetch stage (the larger of NumPrefetchK and NumPrefetchV).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
    {
        // TODO: assume Q is in register
        // TODO: assume K/V has same data type
        constexpr index_t buffer_bytes =
            GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
        constexpr auto buffer_count = max(NumPrefetchK, NumPrefetchV);

        return QXPolicy::template GetSmemSizeQ<Problem>() + buffer_bytes * buffer_count;
    }

    /// Returns the total shared-memory size in bytes: the larger of the Q/K/V requirement
    /// and the value reported by GetSmemSizeDropout<Problem>(0).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        const auto qkv_bytes     = GetSmemSizeKV<Problem>();
        const auto dropout_bytes = GetSmemSizeDropout<Problem>(0);

        return ck_tile::max(qkv_bytes, dropout_bytes);
    }
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment