Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bfc997a7
Commit
bfc997a7
authored
Dec 18, 2024
by
Max Podkorytov
Browse files
update qsksvs pipeline
parent
f7942b99
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
5 deletions
+46
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+46
-5
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
bfc997a7
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -99,8 +100,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -99,8 +100,7 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
const
char
*
name
=
"qs"
;
static
constexpr
const
char
*
name
=
"qs"
;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
using
DropoutType
=
int32_t
;
// unused
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
...
@@ -267,7 +267,8 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -267,7 +267,8 @@ struct BlockFmhaPipelineQSKSVS
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
@@ -620,10 +621,46 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -620,10 +621,46 @@ struct BlockFmhaPipelineQSKSVS
return
o_acc
;
return
o_acc
;
}
}
// template <typename QDramBlockWindowTmp,
// typename KDramBlockWindowTmp,
// typename VDramBlockWindowTmp,
// typename BiasDramBlockWindowTmp,
// typename LSEDramBlockWindowTmp,
// typename PositionEncoding>
// CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
// {
// return operator()(q_dram_block_window_tmp,
// identity{},
// k_dram_block_window_tmp,
// identity{},
// v_dram_block_window_tmp,
// identity{},
// bias_dram_block_window_tmp,
// identity{},
// lse_dram_block_window_tmp,
// identity{},
// identity{},
// identity{},
// identity{},
// mask,
// position_encoding,
// scale_s,
// smem_ptr);
// }
template
<
typename
QDramBlockWindowTmp
,
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
...
@@ -631,11 +668,13 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -631,11 +668,13 @@ struct BlockFmhaPipelineQSKSVS
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
@@ -645,6 +684,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -645,6 +684,7 @@ struct BlockFmhaPipelineQSKSVS
identity
{},
identity
{},
bias_dram_block_window_tmp
,
bias_dram_block_window_tmp
,
identity
{},
identity
{},
randval_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
...
@@ -653,7 +693,8 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -653,7 +693,8 @@ struct BlockFmhaPipelineQSKSVS
mask
,
mask
,
position_encoding
,
position_encoding
,
scale_s
,
scale_s
,
smem_ptr
);
smem_ptr
,
dropout
);
}
}
};
};
...
...
gaoqiong
@gaoqiong
mentioned in commit
1862b27f
·
Feb 18, 2025
mentioned in commit
1862b27f
mentioned in commit 1862b27f349e55c830ce196f2b1f5686573c8fc0
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment