Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
27a2a0a1
Unverified
Commit
27a2a0a1
authored
Jan 29, 2025
by
Max Podkorytov
Browse files
abstract score_mod from a pipeline
parent
529bda90
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
110 deletions
+31
-110
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
+25
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
...e/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
+3
-55
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
...fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
+3
-55
No files found.
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
View file @
27a2a0a1
...
@@ -1302,6 +1302,21 @@ struct FmhaFwdKernel
...
@@ -1302,6 +1302,21 @@ struct FmhaFwdKernel
}
}
}();
}();
auto
score_mod_def
=
[](
auto
s
,
ck_tile
::
index_t
b
,
ck_tile
::
index_t
h
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
(
void
)
h
;
(
void
)
b
;
return
s
+
static_cast
<
decltype
(
s
)
>
(
q_idx
-
v_idx
);
};
auto
score_mod_arg
=
[
b
=
i_batch
,
h
=
i_nhead
,
score_mod_def
](
auto
s
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
return
score_mod_def
(
s
,
b
,
h
,
q_idx
,
v_idx
);
};
auto
o_acc_tile
=
[
&
]()
{
auto
o_acc_tile
=
[
&
]()
{
if
constexpr
(
kDoFp8StaticQuant
)
if
constexpr
(
kDoFp8StaticQuant
)
{
{
...
@@ -1318,6 +1333,7 @@ struct FmhaFwdKernel
...
@@ -1318,6 +1333,7 @@ struct FmhaFwdKernel
lse_dram_window
,
lse_dram_window
,
identity
{},
// lse_element_func
identity
{},
// lse_element_func
identity
{},
// s_acc_element_func
identity
{},
// s_acc_element_func
score_mod_arg
,
scales
{
kargs
.
scale_p
},
// p_compute_element_func
scales
{
kargs
.
scale_p
},
// p_compute_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
mask
,
mask
,
...
@@ -1329,11 +1345,20 @@ struct FmhaFwdKernel
...
@@ -1329,11 +1345,20 @@ struct FmhaFwdKernel
else
else
{
{
return
FmhaPipeline
{}(
q_dram_window
,
return
FmhaPipeline
{}(
q_dram_window
,
identity
{},
k_dram_window
,
k_dram_window
,
identity
{},
v_dram_window
,
v_dram_window
,
identity
{},
bias_dram_window
,
bias_dram_window
,
identity
{},
randval_dram_window
,
randval_dram_window
,
lse_dram_window
,
lse_dram_window
,
identity
{},
identity
{},
score_mod_arg
,
identity
{},
identity
{},
mask
,
mask
,
position_encoding
,
position_encoding
,
kargs
.
scale_s
,
kargs
.
scale_s
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
View file @
27a2a0a1
...
@@ -120,6 +120,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -120,6 +120,7 @@ struct BlockFmhaPipelineQRKSVS
typename
BiasElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
SAccElementFunction
,
typename
ScoreModFunction
,
typename
PComputeElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
typename
PositionEncoding
>
...
@@ -136,6 +137,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -136,6 +137,7 @@ struct BlockFmhaPipelineQRKSVS
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
ScoreModFunction
&
score_mod
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
FmhaMask
mask
,
...
@@ -339,16 +341,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -339,16 +341,6 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale_s, add bias, mask, softmax
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
// FlexAttention score modifier operates directly on scores
{
{
auto
score_mod
=
[](
auto
s
,
ck_tile
::
index_t
b
,
ck_tile
::
index_t
h
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
(
void
)
s
;
(
void
)
b
;
(
void
)
h
;
return
static_cast
<
decltype
(
s
)
>
(
q_idx
-
v_idx
);
};
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
...
@@ -361,10 +353,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -361,10 +353,7 @@ struct BlockFmhaPipelineQRKSVS
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
b
=
0
;
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
const
auto
h
=
0
;
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
b
,
h
,
row
,
col
);
});
});
});
});
}
}
...
@@ -634,47 +623,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -634,47 +623,6 @@ struct BlockFmhaPipelineQRKSVS
return
o_acc
;
return
o_acc
;
}
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
View file @
27a2a0a1
...
@@ -138,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -138,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename
BiasElementFunction
,
typename
BiasElementFunction
,
typename
LSEElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
SAccElementFunction
,
typename
ScoreModFunction
,
typename
PComputeElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
typename
PositionEncoding
>
...
@@ -154,6 +155,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -154,6 +155,7 @@ struct BlockFmhaPipelineQRKSVSAsync
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
LSEDramBlockWindowTmp
&
lse_dram_window_tmp
,
// M0*1 tile
const
LSEElementFunction
&
lse_element_func
,
const
LSEElementFunction
&
lse_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
ScoreModFunction
&
score_mod
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
FmhaMask
mask
,
...
@@ -409,16 +411,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -409,16 +411,6 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale_s, add bias, mask, softmax
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
// FlexAttention score modifier operates directly on scores
{
{
auto
score_mod
=
[](
auto
s
,
ck_tile
::
index_t
b
,
ck_tile
::
index_t
h
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
(
void
)
s
;
(
void
)
b
;
(
void
)
h
;
return
static_cast
<
decltype
(
s
)
>
(
q_idx
-
v_idx
);
};
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
const
auto
k_origin
=
k_dram_block_window
.
get_window_origin
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
constexpr
auto
s_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
...
@@ -431,10 +423,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -431,10 +423,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
b
=
0
;
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
const
auto
h
=
0
;
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
b
,
h
,
row
,
col
);
});
});
});
});
}
}
...
@@ -770,47 +759,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -770,47 +759,6 @@ struct BlockFmhaPipelineQRKSVSAsync
return
o_acc
;
return
o_acc
;
}
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
BiasDramBlockWindowTmp
,
typename
RandValDramBlockWindowTmp
,
typename
LSEDramBlockWindowTmp
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
// M0*N0 tile
LSEDramBlockWindowTmp
&
lse_dram_block_window_tmp
,
// M0*1 tile
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
void
*
smem_ptr
,
DropoutType
&
dropout
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
k_dram_block_window_tmp
,
identity
{},
v_dram_block_window_tmp
,
identity
{},
bias_dram_block_window_tmp
,
identity
{},
randval_dram_block_window_tmp
,
lse_dram_block_window_tmp
,
identity
{},
identity
{},
identity
{},
identity
{},
mask
,
position_encoding
,
scale_s
,
smem_ptr
,
dropout
);
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment