Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f7942b99
"...resnet50_tensorflow.git" did not exist on "6ec3452c61d4bb530c9c080906ebf49dd7cbd342"
Commit
f7942b99
authored
Dec 17, 2024
by
Max Podkorytov
Browse files
qsksvs pipeline changes to mirror qrksvs
parent
d5c8a334
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
3 deletions
+50
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+2
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+4
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+43
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+1
-1
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
f7942b99
...
@@ -95,6 +95,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
...
@@ -95,6 +95,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
{
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
512
)
{
return
1
;
}
}
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
f7942b99
...
@@ -96,6 +96,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -96,6 +96,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
{
return
1
;
return
1
;
}
}
else
if
constexpr
(
kQKHeaddim
<=
512
)
{
return
1
;
}
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
f7942b99
...
@@ -12,7 +12,7 @@ namespace ck_tile {
...
@@ -12,7 +12,7 @@ namespace ck_tile {
/// NOTICE: we no-longer use this pipeline.
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
struct
[[
deprecated
]]
BlockFmhaPipelineQSKSVS
struct
BlockFmhaPipelineQSKSVS
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
...
@@ -81,6 +99,9 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -81,6 +99,9 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static
constexpr
const
char
*
name
=
"qs"
;
static
constexpr
const
char
*
name
=
"qs"
;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using
DropoutType
=
int32_t
;
// unused
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -95,6 +116,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -95,6 +116,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
QElementFunction
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
KElementFunction
,
...
@@ -106,6 +128,23 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -106,6 +128,23 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
typename
OAccElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const QElementFunction& q_element_func,
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const KElementFunction& k_element_func,
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const VElementFunction& v_element_func,
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// const BiasElementFunction& bias_element_func,
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
// const LSEElementFunction& lse_element_func,
// const SAccElementFunction& s_acc_element_func,
// const PComputeElementFunction& p_compute_element_func,
// const OAccElementFunction& o_acc_element_func,
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
...
@@ -114,6 +153,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -114,6 +153,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
const
VElementFunction
&
v_element_func
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
const
BiasElementFunction
&
bias_element_func
,
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
...
@@ -122,7 +162,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
...
@@ -122,7 +162,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
f7942b99
...
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
/// NOTICE: we no-longer use this policy.
/// NOTICE: we no-longer use this policy.
template
<
>
template
<
>
struct
[[
deprecated
]]
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
{
{
static
constexpr
bool
QLoadOnce
=
false
;
static
constexpr
bool
QLoadOnce
=
false
;
...
...
gaoqiong
@gaoqiong
mentioned in commit
4da5206d
·
Feb 18, 2025
mentioned in commit
4da5206d
mentioned in commit 4da5206d4e128ac8a6ad0937e578cd742f520723
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment