Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4da5206d
Commit
4da5206d
authored
Jan 09, 2025
by
Po Yen Chen
Committed by
GitHub
Jan 09, 2025
Browse files
Revert "qsksvs pipeline changes to mirror qrksvs"
This reverts commit
f7942b99
.
parent
1862b27f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
3 additions
and
50 deletions
+3
-50
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+0
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+0
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+2
-43
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+1
-1
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
4da5206d
...
@@ -95,8 +95,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -95,8 +95,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
{
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
512
)
{
return
1
;
}
}
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
4da5206d
...
@@ -96,10 +96,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -96,10 +96,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
{
return
1
;
return
1
;
}
}
else
if
constexpr
(
kQKHeaddim
<=
512
)
{
return
1
;
}
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
4da5206d
...
@@ -12,7 +12,7 @@ namespace ck_tile {
...
@@ -12,7 +12,7 @@ namespace ck_tile {
/// NOTICE: we no-longer use this pipeline.
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQSKSVS
struct
[[
deprecated
]]
BlockFmhaPipelineQSKSVS
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
@@ -51,24 +51,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -51,24 +51,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
...
@@ -99,9 +81,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -99,9 +81,6 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
const
char
*
name
=
"qs"
;
static
constexpr
const
char
*
name
=
"qs"
;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using
DropoutType
=
int32_t
;
// unused
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -116,7 +95,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -116,7 +95,6 @@ struct BlockFmhaPipelineQSKSVS
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
KElementFunction
,
...
@@ -128,23 +106,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -128,23 +106,6 @@ struct BlockFmhaPipelineQSKSVS
typename
OAccElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const QElementFunction& q_element_func,
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const KElementFunction& k_element_func,
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const VElementFunction& v_element_func,
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// const BiasElementFunction& bias_element_func,
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
// const LSEElementFunction& lse_element_func,
// const SAccElementFunction& s_acc_element_func,
// const PComputeElementFunction& p_compute_element_func,
// const OAccElementFunction& o_acc_element_func,
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
...
@@ -153,7 +114,6 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -153,7 +114,6 @@ struct BlockFmhaPipelineQSKSVS
const
VElementFunction
&
v_element_func
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
@@ -162,8 +122,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -162,8 +122,7 @@ struct BlockFmhaPipelineQSKSVS
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
)
const
DropoutType
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
4da5206d
...
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
/// NOTICE: we no-longer use this policy.
/// NOTICE: we no-longer use this policy.
template
<
>
template
<
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
struct
[[
deprecated
]]
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
{
{
static
constexpr
bool
QLoadOnce
=
false
;
static
constexpr
bool
QLoadOnce
=
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment