Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7e9d2390
Commit
7e9d2390
authored
Jul 27, 2024
by
danyao12
Browse files
dq_acc stride stuff
parent
224a7b02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
29 deletions
+70
-29
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+1
-1
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+69
-28
No files found.
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
View file @
7e9d2390
...
@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
...
@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
set
-x
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
perm
in
0 1
;
do
for
hdim
in
6
4
;
do
for
hdim
in
32 64 128 25
6
;
do
for
mode
in
0 1
;
do
for
mode
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
dbias
in
0 1
;
do
for
dbias
in
0 1
;
do
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
7e9d2390
...
@@ -137,6 +137,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -137,6 +137,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
ck_tile
::
index_t
stride_dv
;
...
@@ -145,6 +146,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -145,6 +146,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
};
};
...
@@ -236,6 +238,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -236,6 +238,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
ck_tile
::
index_t
batch_stride_dv
;
};
};
...
@@ -286,6 +289,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -286,6 +289,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
stride_dbias
,
...
@@ -296,6 +300,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -296,6 +300,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_k
,
...
@@ -304,6 +309,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -304,6 +309,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
batch_stride_dbias
,
...
@@ -335,6 +341,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -335,6 +341,7 @@ struct FmhaBwdDQDKDVKernel
stride_k
,
stride_k
,
stride_v
,
stride_v
,
stride_do
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dk
,
stride_dv
,
stride_dv
,
nhead_stride_q
,
nhead_stride_q
,
...
@@ -342,6 +349,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -342,6 +349,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
batch_stride_lsed
},
// args for common karg
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for dbias
...
@@ -352,6 +360,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -352,6 +360,7 @@ struct FmhaBwdDQDKDVKernel
batch_stride_k
,
batch_stride_k
,
batch_stride_v
,
batch_stride_v
,
batch_stride_do
,
batch_stride_do
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dk
,
batch_stride_dv
};
batch_stride_dv
};
...
@@ -431,6 +440,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -431,6 +440,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
ck_tile
::
index_t
stride_dbias
,
...
@@ -441,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -441,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
,
...
@@ -471,6 +482,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -471,6 +482,7 @@ struct FmhaBwdDQDKDVKernel
stride_k
,
stride_k
,
stride_v
,
stride_v
,
stride_do
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dk
,
stride_dv
,
stride_dv
,
nhead_stride_q
,
nhead_stride_q
,
...
@@ -478,6 +490,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -478,6 +490,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
batch_stride_lsed
},
// args for common karg
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for dbias
...
@@ -571,6 +584,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -571,6 +584,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_do
=
0
;
long_index_t
batch_offset_do
=
0
;
long_index_t
batch_offset_lsed
=
0
;
long_index_t
batch_offset_lsed
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
long_index_t
batch_offset_dk
=
0
;
long_index_t
batch_offset_dk
=
0
;
long_index_t
batch_offset_dv
=
0
;
long_index_t
batch_offset_dv
=
0
;
long_index_t
batch_offset_dbias
=
0
;
long_index_t
batch_offset_dbias
=
0
;
...
@@ -581,13 +595,14 @@ struct FmhaBwdDQDKDVKernel
...
@@ -581,13 +595,14 @@ struct FmhaBwdDQDKDVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
;
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
;
...
@@ -627,13 +642,14 @@ struct FmhaBwdDQDKDVKernel
...
@@ -627,13 +642,14 @@ struct FmhaBwdDQDKDVKernel
}
}
else
else
{
{
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
...
@@ -763,16 +779,16 @@ struct FmhaBwdDQDKDVKernel
...
@@ -763,16 +779,16 @@ struct FmhaBwdDQDKDVKernel
{
{
AccDataType
*
dq_acc_ptr
=
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_
q
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_
dq_acc
+
static_cast
<
long_index_t
>
(
i_tile_n_
)
*
kargs
.
split_stride_dq_acc
+
static_cast
<
long_index_t
>
(
i_tile_n_
)
*
kargs
.
split_stride_dq_acc
+
batch_offset_
q
;
batch_offset_
dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_
q
,
1
),
make_tuple
(
kargs
.
stride_
dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -791,7 +807,8 @@ struct FmhaBwdDQDKDVKernel
...
@@ -791,7 +807,8 @@ struct FmhaBwdDQDKDVKernel
{
{
AccDataType
*
dq_acc_ptr
=
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_q
+
batch_offset_q
;
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
const
auto
dq_acc_dram_naive
=
...
@@ -799,7 +816,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -799,7 +816,7 @@ struct FmhaBwdDQDKDVKernel
memory_operation_enum
::
atomic_add
>
(
memory_operation_enum
::
atomic_add
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_
q
,
1
),
make_tuple
(
kargs
.
stride_
dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -1366,7 +1383,9 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1366,7 +1383,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
};
};
struct
FmhaBwdConvertQGradDeterministicKargs
struct
FmhaBwdConvertQGradDeterministicKargs
...
@@ -1381,6 +1400,7 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1381,6 +1400,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs
<
0
>>
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
{
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dq_acc
;
};
};
struct
FmhaBwdConvertQGradGroupModeKargs
struct
FmhaBwdConvertQGradGroupModeKargs
...
@@ -1405,13 +1425,25 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1405,13 +1425,25 @@ struct FmhaBwdConvertQGradKernel
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dq
,
ck_tile
::
index_t
batch_stride_dq
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
)
ck_tile
::
index_t
split_stride_dq_acc
)
{
{
Kargs
kargs
{{
dq_acc_ptr
,
dq_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
stride_dq
,
nhead_stride_dq
},
Kargs
kargs
{{
dq_acc_ptr
,
dq_ptr
,
seqlen_q
,
seqlen_k
,
hdim_q
,
stride_dq
,
stride_dq_acc
,
nhead_stride_dq
,
nhead_stride_dq_acc
},
{},
{},
batch_stride_dq
};
batch_stride_dq
,
batch_stride_dq_acc
};
if
constexpr
(
kIsDeterministic
)
if
constexpr
(
kIsDeterministic
)
{
{
...
@@ -1429,7 +1461,9 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1429,7 +1461,9 @@ struct FmhaBwdConvertQGradKernel
const
void
*
seqstart_k_ptr
,
const
void
*
seqstart_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
split_stride_dq_acc
)
ck_tile
::
index_t
split_stride_dq_acc
)
{
{
Kargs
kargs
{{
dq_acc_ptr
,
Kargs
kargs
{{
dq_acc_ptr
,
...
@@ -1438,7 +1472,9 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1438,7 +1472,9 @@ struct FmhaBwdConvertQGradKernel
-
1
,
//
-
1
,
//
hdim_q
,
hdim_q
,
stride_dq
,
stride_dq
,
nhead_stride_dq
},
stride_dq_acc
,
nhead_stride_dq
,
nhead_stride_dq_acc
},
{},
{},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
)};
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
)};
...
@@ -1477,12 +1513,14 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1477,12 +1513,14 @@ struct FmhaBwdConvertQGradKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
long_index_t
batch_offset_dq
=
0
;
long_index_t
batch_offset_dq
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
if
constexpr
(
kIsGroupMode
)
if
constexpr
(
kIsGroupMode
)
{
{
// get starting offset for each batch
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_dq
=
query_start
*
kargs
.
stride_dq
;
batch_offset_dq
=
query_start
*
kargs
.
stride_dq
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
@@ -1501,7 +1539,8 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1501,7 +1539,8 @@ struct FmhaBwdConvertQGradKernel
}
}
else
else
{
{
batch_offset_dq
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq
;
batch_offset_dq
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
}
}
// for simplicity, batch stride we just modify the pointer
// for simplicity, batch stride we just modify the pointer
...
@@ -1515,14 +1554,15 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1515,14 +1554,15 @@ struct FmhaBwdConvertQGradKernel
{
{
const
AccDataType
*
dq_acc_ptr
=
const
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq
)
+
batch_offset_dq
;
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq_acc
)
+
batch_offset_dq_acc
;
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
nsplits
,
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
nsplits
,
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
split_stride_dq_acc
,
kargs
.
stride_dq
,
1
),
make_tuple
(
kargs
.
split_stride_dq_acc
,
kargs
.
stride_dq
_acc
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
return
pad_tensor_view
(
dq_acc_dram_naive
,
...
@@ -1533,12 +1573,13 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1533,12 +1573,13 @@ struct FmhaBwdConvertQGradKernel
{
{
const
AccDataType
*
dq_acc_ptr
=
const
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq
)
+
batch_offset_dq
;
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq_acc
)
+
batch_offset_dq_acc
;
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq
,
1
),
make_tuple
(
kargs
.
stride_dq
_acc
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
return
pad_tensor_view
(
dq_acc_dram_naive
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment