Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3efdb593
Commit
3efdb593
authored
Aug 06, 2024
by
danyao12
Browse files
unpadded lse&d for group mode
parent
25db1339
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
35 deletions
+28
-35
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+5
-5
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+1
-3
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+22
-27
No files found.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
3efdb593
...
@@ -287,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -287,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
ODataType
>
o_host
(
ck_tile
::
HostTensor
<
ODataType
>
o_host
(
get_lengths
(
o_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_v
));
get_lengths
(
o_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_v
));
ck_tile
::
HostTensor
<
LSEDataType
>
lse_host
(
ck_tile
::
HostTensor
<
LSEDataType
>
lse_host
(
std
::
array
<
ck_tile
::
index_t
,
3
>
{
batch
,
nhead
,
max
_seqlen_q
});
std
::
array
<
ck_tile
::
index_t
,
3
>
{
shape_
batch
,
nhead
,
shape
_seqlen_q
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
(
ck_tile
::
HostTensor
<
DDataType
>
d_host
(
std
::
array
<
ck_tile
::
index_t
,
3
>
{
batch
,
nhead
,
max
_seqlen_q
});
std
::
array
<
ck_tile
::
index_t
,
3
>
{
shape_
batch
,
nhead
,
shape
_seqlen_q
});
ck_tile
::
HostTensor
<
RandValOutputDataType
>
randval_host
(
ck_tile
::
HostTensor
<
RandValOutputDataType
>
randval_host
(
p_drop
>
0
?
get_lengths
(
true
,
shape_batch
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
p_drop
>
0
?
get_lengths
(
true
,
shape_batch
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
...
@@ -441,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -441,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_do
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_do
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_lsed
=
max
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lsed
=
shape
_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_dbias
=
const
ck_tile
::
index_t
nhead_stride_dbias
=
(
i_perm
?
shape_seqlen_q
*
max_seqlen_k
:
max_seqlen_k
);
(
i_perm
?
shape_seqlen_q
*
max_seqlen_k
:
max_seqlen_k
);
// setup batch_stride_* arguments
// setup batch_stride_* arguments
...
@@ -452,7 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -452,7 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_do
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_do
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_lsed
=
(
nhead
*
max
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lsed
=
(
nhead
*
shape
_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_dk
=
(
nhead
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_dk
=
(
nhead
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_dv
=
(
nhead
*
shape_seqlen_k
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_dv
=
(
nhead
*
shape_seqlen_k
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_dbias
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_dbias
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
...
@@ -749,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -749,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
o_perm
)
o_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
o_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
])
=
self
(
idx
);
});
if
(
o_perm
)
o_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
o_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
])
=
self
(
idx
);
});
else
o_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
o_host
(
b
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
])
=
self
(
idx
);
});
else
o_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
o_host
(
b
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
])
=
self
(
idx
);
});
lse_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
lse_host
(
w
b
,
idx
[
0
],
idx
[
1
])
=
self
(
idx
);
});
lse_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
lse_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
)
=
self
(
idx
);
});
// clang-format on
// clang-format on
q_host_refs
.
push_back
(
q_host_ref
);
q_host_refs
.
push_back
(
q_host_ref
);
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
3efdb593
...
@@ -187,7 +187,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -187,7 +187,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
nhead_stride_dk
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dbias
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_lsed
,
args
.
split_stride_dq_acc
,
args
.
split_stride_dq_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
...
@@ -278,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -278,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args
.
stride_o
,
args
.
stride_o
,
args
.
nhead_stride_do
,
args
.
nhead_stride_do
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_lsed
);
args
.
batch_stride_lsed
);
}
}
else
else
{
// create batch mode kernel arguments
{
// create batch mode kernel arguments
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
3efdb593
...
@@ -155,8 +155,6 @@ struct FmhaBwdDQDKDVKernel
...
@@ -155,8 +155,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
ck_tile
::
index_t
nhead_stride_dv
;
ck_tile
::
index_t
batch_stride_lsed
;
};
};
struct
FmhaBwdCommonBiasKargs
struct
FmhaBwdCommonBiasKargs
...
@@ -246,6 +244,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -246,6 +244,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
ck_tile
::
index_t
batch_stride_dv
;
...
@@ -361,17 +360,17 @@ struct FmhaBwdDQDKDVKernel
...
@@ -361,17 +360,17 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dv
},
// args for common karg
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for dropout
{},
// placeholder for deterministic
{},
// placeholder for deterministic
batch_stride_q
,
batch_stride_q
,
batch_stride_k
,
batch_stride_k
,
batch_stride_v
,
batch_stride_v
,
batch_stride_do
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dk
,
batch_stride_dv
};
batch_stride_dv
};
...
@@ -467,7 +466,6 @@ struct FmhaBwdDQDKDVKernel
...
@@ -467,7 +466,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
...
@@ -506,13 +504,12 @@ struct FmhaBwdDQDKDVKernel
...
@@ -506,13 +504,12 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dk
,
nhead_stride_dv
,
nhead_stride_dv
},
// args for common karg
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for dropout
{},
// placeholder for deterministic
{},
// placeholder for deterministic
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
...
@@ -615,7 +612,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -615,7 +612,7 @@ struct FmhaBwdDQDKDVKernel
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_lsed
=
query_start
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
...
@@ -1142,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1142,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_d
;
ck_tile
::
index_t
nhead_stride_d
;
ck_tile
::
index_t
batch_stride_d
;
};
};
struct
FmhaBwdOGradDotOBatchModeKargs
:
FmhaBwdOGradDotOCommonKargs
struct
FmhaBwdOGradDotOBatchModeKargs
:
FmhaBwdOGradDotOCommonKargs
{
{
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_d
;
};
};
struct
FmhaBwdOGradDotOGroupModeKargs
:
FmhaBwdOGradDotOCommonKargs
struct
FmhaBwdOGradDotOGroupModeKargs
:
FmhaBwdOGradDotOCommonKargs
...
@@ -1186,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1186,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
stride_o
,
nhead_stride_do
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_o
,
nhead_stride_d
,
nhead_stride_d
},
batch_stride_d
},
batch_stride_do
,
batch_stride_do
,
batch_stride_o
};
batch_stride_o
,
batch_stride_d
};
return
kargs
;
return
kargs
;
}
}
...
@@ -1206,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1206,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_d
,
ck_tile
::
index_t
nhead_stride_d
)
ck_tile
::
index_t
batch_stride_d
)
{
{
Kargs
kargs
{{
o_ptr
,
Kargs
kargs
{{
o_ptr
,
do_ptr
,
do_ptr
,
...
@@ -1219,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1219,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
stride_o
,
nhead_stride_do
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_o
,
nhead_stride_d
,
nhead_stride_d
},
batch_stride_d
},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
)};
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
)};
return
kargs
;
return
kargs
;
...
@@ -1263,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
...
@@ -1263,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o
=
query_start
*
kargs
.
stride_o
;
batch_offset_o
=
query_start
*
kargs
.
stride_o
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_d
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_d
;
batch_offset_d
=
query_start
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment