Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
f266fc72
Commit
f266fc72
authored
Jan 03, 2023
by
Tri Dao
Browse files
[Gen, FT] Use tlength instead of params.timestep for rotary
parent
a01d1213
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
...attention/decoder_masked_multihead_attention_template.hpp
+4
-4
No files found.
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
View file @
f266fc72
...
...
@@ -1082,10 +1082,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const
int
padd_len
=
(
params
.
total_padding_tokens
==
nullptr
)
?
0
:
params
.
total_padding_tokens
[
bi
];
if
(
params
.
rotary_embedding_dim
>
0
&&
!
params
.
neox_rotary_style
)
{
if
(
handle_kv
)
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
params
.
timestep
-
padd_len
);
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
);
}
else
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
params
.
timestep
-
padd_len
);
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
);
}
}
else
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
...
...
@@ -1120,13 +1120,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
mmha
::
vec_from_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
mmha
::
apply_rotary_embedding
(
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
params
.
timestep
-
padd_len
);
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
);
mmha
::
write_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
params
.
timestep
);
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
);
}
mmha
::
write_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment