Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e6c489df
Commit
e6c489df
authored
Aug 06, 2024
by
rocking
Browse files
Support unpad layout for group lse
parent
3efdb593
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
16 deletions
+16
-16
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+14
-11
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+0
-1
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+2
-4
No files found.
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
e6c489df
...
@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
std
::
array
<
ck_tile
::
index_t
,
2
>
{
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
2
>
{
1
,
1
});
ck_tile
::
HostTensor
<
LSEDataType
>
lse_acc_host
(
ck_tile
::
HostTensor
<
LSEDataType
>
lse_acc_host
(
1
<
num_splits
?
std
::
array
<
ck_tile
::
index_t
,
4
>
{
num_splits
,
batch
,
nhead
,
max_seqlen_q
}
1
<
num_splits
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
?
std
::
array
<
ck_tile
::
index_t
,
4
>
{
num_splits
,
shape_batch
,
nhead
,
shape_seqlen_q
}
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
1
<
num_splits
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
num_splits
,
batch
,
nhead
,
max_seqlen_q
,
hdim_v
}
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
num_splits
,
batch
,
nhead
,
max_seqlen_q
,
hdim_v
}
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
// self define lse data layout as [batch, nhead, max_seqlen_q]
// batch mode of lse data layout is [batch, nhead, max_seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile
::
HostTensor
<
LSEDataType
>
lse_host
(
ck_tile
::
HostTensor
<
LSEDataType
>
lse_host
(
lse
?
std
::
array
<
ck_tile
::
index_t
,
3
>
{
batch
,
nhead
,
max
_seqlen_q
}
lse
?
std
::
array
<
ck_tile
::
index_t
,
3
>
{
shape_
batch
,
nhead
,
shape
_seqlen_q
}
:
std
::
array
<
ck_tile
::
index_t
,
3
>
{
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
:
std
::
array
<
ck_tile
::
index_t
,
3
>
{
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
ODataType
>
o_host
(
ck_tile
::
HostTensor
<
ODataType
>
o_host
(
...
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_bias
=
const
ck_tile
::
index_t
nhead_stride_bias
=
(
i_perm
?
0
*
shape_seqlen_q
*
shape_seqlen_k
:
0
*
shape_seqlen_k
);
(
i_perm
?
0
*
shape_seqlen_q
*
shape_seqlen_k
:
0
*
shape_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_lse
=
max
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse
=
shape
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
max
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
shape
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
// setup batch_stride_* arguments
// setup batch_stride_* arguments
...
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_v
=
(
nhead_k
*
hdim_v
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_v
=
(
nhead_k
*
hdim_v
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_bias
=
(
0
*
nhead
*
shape_seqlen_q
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_bias
=
(
0
*
nhead
*
shape_seqlen_q
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
max
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
max
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
// setup split_stride_* arguments (only used in split-kv kernel)
// setup split_stride_* arguments (only used in split-kv kernel)
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
batch
*
nhead
*
max
_seqlen_q
);
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
batch
*
nhead
*
shape
_seqlen_q
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
batch
*
nhead
*
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
batch
*
nhead
*
max_seqlen_q
*
hdim_v
);
return
fmha_fwd_args
{
q_buf
.
GetDeviceBuffer
(),
return
fmha_fwd_args
{
q_buf
.
GetDeviceBuffer
(),
...
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
lse
)
if
(
lse
)
{
{
ck_tile
::
HostTensor
<
SMPLComputeDataType
>
lse_host_result
({
nhead
,
real_seqlen_q
});
ck_tile
::
HostTensor
<
SMPLComputeDataType
>
lse_host_result
({
nhead
,
real_seqlen_q
});
lse_host_result
.
ForEach
(
lse_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_host
(
wb
,
idx
[
0
],
idx
[
1
]);
});
self
(
idx
)
=
lse_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
);
});
cur_pass
=
ck_tile
::
check_err
(
lse_host_result
,
cur_pass
=
ck_tile
::
check_err
(
lse_host_result
,
lse_host_ref
,
lse_host_ref
,
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
e6c489df
...
@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
batch_stride_lse
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
e6c489df
...
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
...
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _SS_
#undef _TS_
#undef _TS_
...
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
...
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_lse
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
...
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
...
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{
{
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
lse_ptr
=
lse_ptr
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
nhead_stride_lse
=
nhead_stride_lse
;
kargs
.
batch_stride_lse
=
batch_stride_lse
;
}
}
if
constexpr
(
kDoFp8StaticQuant
)
if
constexpr
(
kDoFp8StaticQuant
)
{
{
...
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
...
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
}
}
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
query_start
;
}
}
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment