Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4da5206d
Commit
4da5206d
authored
Jan 09, 2025
by
Po Yen Chen
Committed by
GitHub
Jan 09, 2025
Browse files
Revert "qsksvs pipeline changes to mirror qrksvs"
This reverts commit
f7942b99
.
parent
1862b27f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
3 additions
and
50 deletions
+3
-50
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+0
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+0
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+2
-43
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+1
-1
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
4da5206d
...
...
@@ -95,8 +95,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
512
)
{
return
1
;
}
}
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
4da5206d
...
...
@@ -96,10 +96,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
return
1
;
}
else
if
constexpr
(
kQKHeaddim
<=
512
)
{
return
1
;
}
}
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
4da5206d
...
...
@@ -12,7 +12,7 @@ namespace ck_tile {
/// NOTICE: we no longer use this pipeline.
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQSKSVS
struct
[[
deprecated
]]
BlockFmhaPipelineQSKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
...
@@ -51,24 +51,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
...
...
@@ -99,9 +81,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
const
char
*
name
=
"qs"
;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using
DropoutType
=
int32_t
;
// unused
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
@@ -116,7 +95,6 @@ struct BlockFmhaPipelineQSKSVS
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
...
...
@@ -128,23 +106,6 @@ struct BlockFmhaPipelineQSKSVS
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const QElementFunction& q_element_func,
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const KElementFunction& k_element_func,
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const VElementFunction& v_element_func,
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// const BiasElementFunction& bias_element_func,
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
// const LSEElementFunction& lse_element_func,
// const SAccElementFunction& s_acc_element_func,
// const PComputeElementFunction& p_compute_element_func,
// const OAccElementFunction& o_acc_element_func,
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
...
...
@@ -153,7 +114,6 @@ struct BlockFmhaPipelineQSKSVS
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
...
@@ -162,8 +122,7 @@ struct BlockFmhaPipelineQSKSVS
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
void
*
smem_ptr
)
const
{
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
4da5206d
...
...
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
/// NOTICE: we no longer use this policy.
template
<
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
struct
[[
deprecated
]]
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
{
static
constexpr
bool
QLoadOnce
=
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment